pszemraj commited on
Commit
76b1322
1 Parent(s): 304970e
Files changed (1) hide show
  1. modeling_mpt.py +274 -80
modeling_mpt.py CHANGED
@@ -9,63 +9,90 @@ import torch
9
  import torch.nn as nn
10
  import torch.nn.functional as F
11
  from transformers import PreTrainedModel, PreTrainedTokenizer, PreTrainedTokenizerFast
12
- from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
 
 
 
13
  from .attention import attn_bias_shape, build_attn_bias
14
  from .blocks import MPTBlock
15
  from .norm import NORM_CLASS_REGISTRY
16
  from .configuration_mpt import MPTConfig
17
  from .adapt_tokenizer import AutoTokenizerForMOD, adapt_tokenizer_for_denoising
18
- from .hf_prefixlm_converter import add_bidirectional_mask_if_missing, convert_hf_causal_lm_to_prefix_lm
 
 
 
19
  from .meta_init_context import init_empty_weights
20
  from .param_init_fns import MODEL_INIT_REGISTRY, generic_param_init_fn_
 
21
  Tokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
22
 
 
23
  class MPTPreTrainedModel(PreTrainedModel):
24
  config_class = MPTConfig
25
- base_model_prefix = 'model'
26
  supports_gradient_checkpointing = True
27
  _no_split_modules = []
28
-
29
- class MPTModel(MPTPreTrainedModel):
30
 
 
 
31
  def __init__(self, config: MPTConfig):
32
  config._validate_config()
33
  super().__init__(config)
34
- self.attn_impl = config.attn_config['attn_impl']
35
- self.prefix_lm = config.attn_config['prefix_lm']
36
- self.attn_uses_sequence_id = config.attn_config['attn_uses_sequence_id']
37
- self.alibi = config.attn_config['alibi']
38
- self.alibi_bias_max = config.attn_config['alibi_bias_max']
39
  if config.norm_type.lower() not in NORM_CLASS_REGISTRY.keys():
40
- norm_options = ' | '.join(NORM_CLASS_REGISTRY.keys())
41
- raise NotImplementedError(f'Requested norm type ({config.norm_type}) is not implemented within this repo (Options: {norm_options}).')
 
 
42
  norm_class = NORM_CLASS_REGISTRY[config.norm_type.lower()]
43
  self.embedding_fraction = config.embedding_fraction
44
- self.wte = nn.Embedding(config.vocab_size, config.d_model, device=config.init_device)
 
 
45
  if not self.alibi:
46
- self.wpe = nn.Embedding(config.max_seq_len, config.d_model, device=config.init_device)
 
 
47
  self.emb_drop = nn.Dropout(config.emb_pdrop)
48
- self.blocks = nn.ModuleList([MPTBlock(device=config.init_device, **config.to_dict()) for _ in range(config.n_layers)])
 
 
 
 
 
49
  self.norm_f = norm_class(config.d_model, device=config.init_device)
50
- if config.init_device != 'meta':
51
  self.apply(self.param_init_fn)
52
  self.is_causal = not self.prefix_lm
53
  self._attn_bias_initialized = False
54
  self.attn_bias = None
55
- self.attn_bias_shape = attn_bias_shape(self.attn_impl, config.n_heads, config.max_seq_len, self.alibi, prefix_lm=self.prefix_lm, causal=self.is_causal, use_sequence_id=self.attn_uses_sequence_id)
 
 
 
 
 
 
 
 
56
  if config.no_bias:
57
  for module in self.modules():
58
- if hasattr(module, 'bias') and isinstance(module.bias, nn.Parameter):
59
  if config.verbose:
60
- warnings.warn(f'Removing bias ({module.bias}) from {module}.')
61
- module.register_parameter('bias', None)
62
  if config.verbose and config.verbose > 2:
63
  print(self)
64
- if 'verbose' not in self.config.init_config:
65
- self.config.init_config['verbose'] = self.config.verbose
66
- if self.config.init_config['verbose'] > 1:
67
- init_fn_name = self.config.init_config['name']
68
- warnings.warn(f'Using {init_fn_name} initialization.')
69
 
70
  def get_input_embeddings(self):
71
  return self.wte
@@ -74,13 +101,30 @@ class MPTModel(MPTPreTrainedModel):
74
  self.wte = value
75
 
76
  @torch.no_grad()
77
- def _attn_bias(self, device, dtype, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None):
 
 
 
 
 
 
 
78
  if not self._attn_bias_initialized:
79
  if self.attn_bias_shape:
80
- self.attn_bias = torch.zeros(self.attn_bias_shape, device=device, dtype=dtype)
81
- self.attn_bias = build_attn_bias(self.attn_impl, self.attn_bias, self.config.n_heads, self.config.max_seq_len, causal=self.is_causal, alibi=self.alibi, alibi_bias_max=self.alibi_bias_max)
 
 
 
 
 
 
 
 
 
 
82
  self._attn_bias_initialized = True
83
- if self.attn_impl == 'flash':
84
  return (self.attn_bias, attention_mask)
85
  if self.attn_bias is not None:
86
  self.attn_bias = self.attn_bias.to(dtype=dtype, device=device)
@@ -99,58 +143,110 @@ class MPTModel(MPTPreTrainedModel):
99
  else:
100
  attn_bias = attn_bias[:, :, :, -s_k:]
101
  if prefix_mask is not None and attention_mask.shape != prefix_mask.shape:
102
- raise ValueError(f'attention_mask shape={attention_mask.shape} ' + f'and prefix_mask shape={prefix_mask.shape} are not equal.')
 
 
 
103
  min_val = torch.finfo(attn_bias.dtype).min
104
- attn_bias = attn_bias.masked_fill(~attention_mask.view(-1, 1, 1, s_k), min_val)
 
 
105
  return (attn_bias, None)
106
 
107
  def _apply_prefix_mask(self, attn_bias: torch.Tensor, prefix_mask: torch.Tensor):
108
  (s_k, s_q) = attn_bias.shape[-2:]
109
  if s_k != self.config.max_seq_len or s_q != self.config.max_seq_len:
110
- raise ValueError('attn_bias does not match the expected shape. ' + f'The last two dimensions should both be {self.config.max_length} ' + f'but are {s_k} and {s_q}.')
 
 
 
 
111
  seq_len = prefix_mask.shape[-1]
112
  if seq_len > self.config.max_seq_len:
113
- raise ValueError(f'prefix_mask sequence length cannot exceed max_seq_len={self.config.max_seq_len}')
 
 
114
  attn_bias = attn_bias[..., :seq_len, :seq_len]
115
- causal = torch.tril(torch.ones((seq_len, seq_len), dtype=torch.bool, device=prefix_mask.device)).view(1, 1, seq_len, seq_len)
 
 
116
  prefix = prefix_mask.view(-1, 1, 1, seq_len)
117
  cannot_attend = ~torch.logical_or(causal, prefix.bool())
118
  min_val = torch.finfo(attn_bias.dtype).min
119
  attn_bias = attn_bias.masked_fill(cannot_attend, min_val)
120
  return attn_bias
121
 
122
- def _apply_sequence_id(self, attn_bias: torch.Tensor, sequence_id: torch.LongTensor):
 
 
123
  seq_len = sequence_id.shape[-1]
124
  if seq_len > self.config.max_seq_len:
125
- raise ValueError(f'sequence_id sequence length cannot exceed max_seq_len={self.config.max_seq_len}')
 
 
126
  attn_bias = attn_bias[..., :seq_len, :seq_len]
127
- cannot_attend = torch.logical_not(torch.eq(sequence_id.view(-1, seq_len, 1), sequence_id.view(-1, 1, seq_len))).unsqueeze(1)
 
 
128
  min_val = torch.finfo(attn_bias.dtype).min
129
  attn_bias = attn_bias.masked_fill(cannot_attend, min_val)
130
  return attn_bias
131
 
132
- def forward(self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple[torch.FloatTensor]]]=None, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None, return_dict: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, use_cache: Optional[bool]=None):
133
- return_dict = return_dict if return_dict is not None else self.config.return_dict
 
 
 
 
 
 
 
 
 
 
 
 
 
134
  use_cache = use_cache if use_cache is not None else self.config.use_cache
135
  if attention_mask is not None:
136
  attention_mask = attention_mask.bool()
137
  if prefix_mask is not None:
138
  prefix_mask = prefix_mask.bool()
139
  if not return_dict:
140
- raise NotImplementedError('return_dict False is not implemented yet for MPT')
 
 
141
  if output_attentions:
142
- raise NotImplementedError('output_attentions is not implemented yet for MPT')
143
- if attention_mask is not None and attention_mask[:, 0].sum() != attention_mask.shape[0] and self.training:
144
- raise NotImplementedError('MPT does not support training with left padding.')
 
 
 
 
 
 
 
 
145
  if self.prefix_lm and prefix_mask is None:
146
- raise ValueError('prefix_mask is a required argument when MPT is configured with prefix_lm=True.')
 
 
147
  if self.training:
148
  if self.attn_uses_sequence_id and sequence_id is None:
149
- raise ValueError('sequence_id is a required argument when MPT is configured with attn_uses_sequence_id=True ' + 'and the model is in train mode.')
 
 
 
150
  elif self.attn_uses_sequence_id is False and sequence_id is not None:
151
- warnings.warn('MPT received non-None input for `sequence_id` but is configured with attn_uses_sequence_id=False. ' + 'This input will be ignored. If you want the model to use `sequence_id`, set attn_uses_sequence_id to True.')
 
 
 
152
  S = input_ids.size(1)
153
- assert S <= self.config.max_seq_len, f'Cannot forward input with seq_len={S}, this model only supports seq_len<={self.config.max_seq_len}'
 
 
154
  tok_emb = self.wte(input_ids)
155
  if self.alibi:
156
  x = tok_emb
@@ -158,39 +254,80 @@ class MPTModel(MPTPreTrainedModel):
158
  past_position = 0
159
  if past_key_values is not None:
160
  if len(past_key_values) != self.config.n_layers:
161
- raise ValueError(f'past_key_values must provide a past_key_value for each attention ' + f'layer in the network (len(past_key_values)={len(past_key_values)!r}; self.config.n_layers={self.config.n_layers!r}).')
 
 
 
162
  past_position = past_key_values[0][0].size(1)
163
  if S + past_position > self.config.max_seq_len:
164
- raise ValueError(f'Cannot forward input with past sequence length {past_position} and current sequence length {S + 1}, this model only supports total sequence length <= {self.config.max_seq_len}.')
165
- pos = torch.arange(past_position, S + past_position, dtype=torch.long, device=input_ids.device).unsqueeze(0)
 
 
 
 
 
 
 
166
  if attention_mask is not None:
167
- pos = torch.clamp(pos - torch.cumsum((~attention_mask).to(torch.int32), dim=1)[:, past_position:], min=0)
 
 
 
 
 
 
168
  pos_emb = self.wpe(pos)
169
  x = tok_emb + pos_emb
170
  if self.embedding_fraction == 1:
171
  x = self.emb_drop(x)
172
  else:
173
- x_shrunk = x * self.embedding_fraction + x.detach() * (1 - self.embedding_fraction)
 
 
174
  assert isinstance(self.emb_drop, nn.Module)
175
  x = self.emb_drop(x_shrunk)
176
- (attn_bias, attention_mask) = self._attn_bias(device=x.device, dtype=x.dtype, attention_mask=attention_mask, prefix_mask=prefix_mask, sequence_id=sequence_id)
 
 
 
 
 
 
177
  if use_cache and past_key_values is None:
178
  past_key_values = [() for _ in range(self.config.n_layers)]
179
  all_hidden_states = () if output_hidden_states else None
180
- for (b_idx, block) in enumerate(self.blocks):
181
  if output_hidden_states:
182
  assert all_hidden_states is not None
183
  all_hidden_states = all_hidden_states + (x,)
184
- past_key_value = past_key_values[b_idx] if past_key_values is not None else None
185
- (x, past_key_value) = block(x, past_key_value=past_key_value, attn_bias=attn_bias, attention_mask=attention_mask, is_causal=self.is_causal)
 
 
 
 
 
 
 
 
186
  if past_key_values is not None:
187
  past_key_values[b_idx] = past_key_value
188
  x = self.norm_f(x)
189
- return BaseModelOutputWithPast(last_hidden_state=x, past_key_values=past_key_values, hidden_states=all_hidden_states)
 
 
 
 
190
 
191
  def param_init_fn(self, module):
192
- init_fn_name = self.config.init_config['name']
193
- MODEL_INIT_REGISTRY[init_fn_name](module=module, n_layers=self.config.n_layers, d_model=self.config.d_model, **self.config.init_config)
 
 
 
 
 
194
 
195
  def fsdp_wrap_fn(self, module):
196
  return isinstance(module, MPTBlock)
@@ -198,21 +335,23 @@ class MPTModel(MPTPreTrainedModel):
198
  def activation_checkpointing_fn(self, module):
199
  return isinstance(module, MPTBlock)
200
 
201
- class MPTForCausalLM(MPTPreTrainedModel):
202
 
 
203
  def __init__(self, config: MPTConfig):
204
  super().__init__(config)
205
  if not config.tie_word_embeddings:
206
- raise ValueError('MPTForCausalLM only supports tied word embeddings')
207
  self.transformer = MPTModel(config)
208
  self.logit_scale = None
209
  if config.logit_scale is not None:
210
  logit_scale = config.logit_scale
211
  if isinstance(logit_scale, str):
212
- if logit_scale == 'inv_sqrt_d_model':
213
  logit_scale = 1 / math.sqrt(config.d_model)
214
  else:
215
- raise ValueError(f"logit_scale={logit_scale!r} is not recognized as an option; use numeric value or 'inv_sqrt_d_model'.")
 
 
216
  self.logit_scale = logit_scale
217
 
218
  def get_input_embeddings(self):
@@ -233,25 +372,63 @@ class MPTForCausalLM(MPTPreTrainedModel):
233
  def get_decoder(self):
234
  return self.transformer
235
 
236
- def forward(self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple[torch.FloatTensor]]]=None, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None, labels: Optional[torch.LongTensor]=None, return_dict: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, use_cache: Optional[bool]=None):
237
- return_dict = return_dict if return_dict is not None else self.config.return_dict
 
 
 
 
 
 
 
 
 
 
 
 
 
 
238
  use_cache = use_cache if use_cache is not None else self.config.use_cache
239
- outputs = self.transformer(input_ids=input_ids, past_key_values=past_key_values, attention_mask=attention_mask, prefix_mask=prefix_mask, sequence_id=sequence_id, return_dict=return_dict, output_attentions=output_attentions, output_hidden_states=output_hidden_states, use_cache=use_cache)
 
 
 
 
 
 
 
 
 
 
240
  logits = F.linear(outputs.last_hidden_state, self.transformer.wte.weight)
241
  if self.logit_scale is not None:
242
  if self.logit_scale == 0:
243
- warnings.warn(f'Multiplying logits by self.logit_scale={self.logit_scale!r}. This will produce uniform (uninformative) outputs.')
 
 
244
  logits *= self.logit_scale
245
  loss = None
246
  if labels is not None:
247
  labels = torch.roll(labels, shifts=-1)
248
  labels[:, -1] = -100
249
- loss = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.to(logits.device).view(-1))
250
- return CausalLMOutputWithPast(loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states)
 
 
 
 
 
 
 
251
 
252
  def param_init_fn(self, module):
253
- init_fn_name = self.config.init_config['name']
254
- MODEL_INIT_REGISTRY[init_fn_name](module=module, n_layers=self.config.n_layers, d_model=self.config.d_model, **self.config.init_config)
 
 
 
 
 
255
 
256
  def fsdp_wrap_fn(self, module):
257
  return isinstance(module, MPTBlock)
@@ -259,12 +436,16 @@ class MPTForCausalLM(MPTPreTrainedModel):
259
  def activation_checkpointing_fn(self, module):
260
  return isinstance(module, MPTBlock)
261
 
262
- def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
 
 
263
  if inputs_embeds is not None:
264
- raise NotImplementedError('inputs_embeds is not implemented for MPT yet')
265
- attention_mask = kwargs['attention_mask'].bool()
266
  if attention_mask[:, -1].sum() != attention_mask.shape[0]:
267
- raise NotImplementedError('MPT does not support generation with right padding.')
 
 
268
  if self.transformer.attn_uses_sequence_id and self.training:
269
  sequence_id = torch.zeros_like(input_ids[:1])
270
  else:
@@ -273,11 +454,20 @@ class MPTForCausalLM(MPTPreTrainedModel):
273
  input_ids = input_ids[:, -1].unsqueeze(-1)
274
  if self.transformer.prefix_lm:
275
  prefix_mask = torch.ones_like(attention_mask)
276
- if kwargs.get('use_cache') == False:
277
- raise NotImplementedError('MPT with prefix_lm=True does not support use_cache=False.')
 
 
278
  else:
279
  prefix_mask = None
280
- return {'input_ids': input_ids, 'attention_mask': attention_mask, 'prefix_mask': prefix_mask, 'sequence_id': sequence_id, 'past_key_values': past_key_values, 'use_cache': kwargs.get('use_cache', True)}
 
 
 
 
 
 
 
281
 
282
  @staticmethod
283
  def _reorder_cache(past_key_values, beam_idx):
@@ -288,5 +478,9 @@ class MPTForCausalLM(MPTPreTrainedModel):
288
  """
289
  reordered_past = []
290
  for layer_past in past_key_values:
291
- reordered_past += [tuple((past_state.index_select(0, beam_idx) for past_state in layer_past))]
292
- return reordered_past
 
 
 
 
 
9
  import torch.nn as nn
10
  import torch.nn.functional as F
11
  from transformers import PreTrainedModel, PreTrainedTokenizer, PreTrainedTokenizerFast
12
+ from transformers.modeling_outputs import (
13
+ BaseModelOutputWithPast,
14
+ CausalLMOutputWithPast,
15
+ )
16
  from .attention import attn_bias_shape, build_attn_bias
17
  from .blocks import MPTBlock
18
  from .norm import NORM_CLASS_REGISTRY
19
  from .configuration_mpt import MPTConfig
20
  from .adapt_tokenizer import AutoTokenizerForMOD, adapt_tokenizer_for_denoising
21
+ from .hf_prefixlm_converter import (
22
+ add_bidirectional_mask_if_missing,
23
+ convert_hf_causal_lm_to_prefix_lm,
24
+ )
25
  from .meta_init_context import init_empty_weights
26
  from .param_init_fns import MODEL_INIT_REGISTRY, generic_param_init_fn_
27
+
28
  Tokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
29
 
30
+
31
  class MPTPreTrainedModel(PreTrainedModel):
32
  config_class = MPTConfig
33
+ base_model_prefix = "model"
34
  supports_gradient_checkpointing = True
35
  _no_split_modules = []
 
 
36
 
37
+
38
+ class MPTModel(MPTPreTrainedModel):
39
  def __init__(self, config: MPTConfig):
40
  config._validate_config()
41
  super().__init__(config)
42
+ self.attn_impl = config.attn_config["attn_impl"]
43
+ self.prefix_lm = config.attn_config["prefix_lm"]
44
+ self.attn_uses_sequence_id = config.attn_config["attn_uses_sequence_id"]
45
+ self.alibi = config.attn_config["alibi"]
46
+ self.alibi_bias_max = config.attn_config["alibi_bias_max"]
47
  if config.norm_type.lower() not in NORM_CLASS_REGISTRY.keys():
48
+ norm_options = " | ".join(NORM_CLASS_REGISTRY.keys())
49
+ raise NotImplementedError(
50
+ f"Requested norm type ({config.norm_type}) is not implemented within this repo (Options: {norm_options})."
51
+ )
52
  norm_class = NORM_CLASS_REGISTRY[config.norm_type.lower()]
53
  self.embedding_fraction = config.embedding_fraction
54
+ self.wte = nn.Embedding(
55
+ config.vocab_size, config.d_model, device=config.init_device
56
+ )
57
  if not self.alibi:
58
+ self.wpe = nn.Embedding(
59
+ config.max_seq_len, config.d_model, device=config.init_device
60
+ )
61
  self.emb_drop = nn.Dropout(config.emb_pdrop)
62
+ self.blocks = nn.ModuleList(
63
+ [
64
+ MPTBlock(device=config.init_device, **config.to_dict())
65
+ for _ in range(config.n_layers)
66
+ ]
67
+ )
68
  self.norm_f = norm_class(config.d_model, device=config.init_device)
69
+ if config.init_device != "meta":
70
  self.apply(self.param_init_fn)
71
  self.is_causal = not self.prefix_lm
72
  self._attn_bias_initialized = False
73
  self.attn_bias = None
74
+ self.attn_bias_shape = attn_bias_shape(
75
+ self.attn_impl,
76
+ config.n_heads,
77
+ config.max_seq_len,
78
+ self.alibi,
79
+ prefix_lm=self.prefix_lm,
80
+ causal=self.is_causal,
81
+ use_sequence_id=self.attn_uses_sequence_id,
82
+ )
83
  if config.no_bias:
84
  for module in self.modules():
85
+ if hasattr(module, "bias") and isinstance(module.bias, nn.Parameter):
86
  if config.verbose:
87
+ warnings.warn(f"Removing bias ({module.bias}) from {module}.")
88
+ module.register_parameter("bias", None)
89
  if config.verbose and config.verbose > 2:
90
  print(self)
91
+ if "verbose" not in self.config.init_config:
92
+ self.config.init_config["verbose"] = self.config.verbose
93
+ if self.config.init_config["verbose"] > 1:
94
+ init_fn_name = self.config.init_config["name"]
95
+ warnings.warn(f"Using {init_fn_name} initialization.")
96
 
97
  def get_input_embeddings(self):
98
  return self.wte
 
101
  self.wte = value
102
 
103
  @torch.no_grad()
104
+ def _attn_bias(
105
+ self,
106
+ device,
107
+ dtype,
108
+ attention_mask: Optional[torch.ByteTensor] = None,
109
+ prefix_mask: Optional[torch.ByteTensor] = None,
110
+ sequence_id: Optional[torch.LongTensor] = None,
111
+ ):
112
  if not self._attn_bias_initialized:
113
  if self.attn_bias_shape:
114
+ self.attn_bias = torch.zeros(
115
+ self.attn_bias_shape, device=device, dtype=dtype
116
+ )
117
+ self.attn_bias = build_attn_bias(
118
+ self.attn_impl,
119
+ self.attn_bias,
120
+ self.config.n_heads,
121
+ self.config.max_seq_len,
122
+ causal=self.is_causal,
123
+ alibi=self.alibi,
124
+ alibi_bias_max=self.alibi_bias_max,
125
+ )
126
  self._attn_bias_initialized = True
127
+ if self.attn_impl == "flash":
128
  return (self.attn_bias, attention_mask)
129
  if self.attn_bias is not None:
130
  self.attn_bias = self.attn_bias.to(dtype=dtype, device=device)
 
143
  else:
144
  attn_bias = attn_bias[:, :, :, -s_k:]
145
  if prefix_mask is not None and attention_mask.shape != prefix_mask.shape:
146
+ raise ValueError(
147
+ f"attention_mask shape={attention_mask.shape} "
148
+ + f"and prefix_mask shape={prefix_mask.shape} are not equal."
149
+ )
150
  min_val = torch.finfo(attn_bias.dtype).min
151
+ attn_bias = attn_bias.masked_fill(
152
+ ~attention_mask.view(-1, 1, 1, s_k), min_val
153
+ )
154
  return (attn_bias, None)
155
 
156
  def _apply_prefix_mask(self, attn_bias: torch.Tensor, prefix_mask: torch.Tensor):
157
  (s_k, s_q) = attn_bias.shape[-2:]
158
  if s_k != self.config.max_seq_len or s_q != self.config.max_seq_len:
159
+ raise ValueError(
160
+ "attn_bias does not match the expected shape. "
161
+ + f"The last two dimensions should both be {self.config.max_length} "
162
+ + f"but are {s_k} and {s_q}."
163
+ )
164
  seq_len = prefix_mask.shape[-1]
165
  if seq_len > self.config.max_seq_len:
166
+ raise ValueError(
167
+ f"prefix_mask sequence length cannot exceed max_seq_len={self.config.max_seq_len}"
168
+ )
169
  attn_bias = attn_bias[..., :seq_len, :seq_len]
170
+ causal = torch.tril(
171
+ torch.ones((seq_len, seq_len), dtype=torch.bool, device=prefix_mask.device)
172
+ ).view(1, 1, seq_len, seq_len)
173
  prefix = prefix_mask.view(-1, 1, 1, seq_len)
174
  cannot_attend = ~torch.logical_or(causal, prefix.bool())
175
  min_val = torch.finfo(attn_bias.dtype).min
176
  attn_bias = attn_bias.masked_fill(cannot_attend, min_val)
177
  return attn_bias
178
 
179
+ def _apply_sequence_id(
180
+ self, attn_bias: torch.Tensor, sequence_id: torch.LongTensor
181
+ ):
182
  seq_len = sequence_id.shape[-1]
183
  if seq_len > self.config.max_seq_len:
184
+ raise ValueError(
185
+ f"sequence_id sequence length cannot exceed max_seq_len={self.config.max_seq_len}"
186
+ )
187
  attn_bias = attn_bias[..., :seq_len, :seq_len]
188
+ cannot_attend = torch.logical_not(
189
+ torch.eq(sequence_id.view(-1, seq_len, 1), sequence_id.view(-1, 1, seq_len))
190
+ ).unsqueeze(1)
191
  min_val = torch.finfo(attn_bias.dtype).min
192
  attn_bias = attn_bias.masked_fill(cannot_attend, min_val)
193
  return attn_bias
194
 
195
+ def forward(
196
+ self,
197
+ input_ids: torch.LongTensor,
198
+ past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None,
199
+ attention_mask: Optional[torch.ByteTensor] = None,
200
+ prefix_mask: Optional[torch.ByteTensor] = None,
201
+ sequence_id: Optional[torch.LongTensor] = None,
202
+ return_dict: Optional[bool] = None,
203
+ output_attentions: Optional[bool] = None,
204
+ output_hidden_states: Optional[bool] = None,
205
+ use_cache: Optional[bool] = None,
206
+ ):
207
+ return_dict = (
208
+ return_dict if return_dict is not None else self.config.return_dict
209
+ )
210
  use_cache = use_cache if use_cache is not None else self.config.use_cache
211
  if attention_mask is not None:
212
  attention_mask = attention_mask.bool()
213
  if prefix_mask is not None:
214
  prefix_mask = prefix_mask.bool()
215
  if not return_dict:
216
+ raise NotImplementedError(
217
+ "return_dict False is not implemented yet for MPT"
218
+ )
219
  if output_attentions:
220
+ raise NotImplementedError(
221
+ "output_attentions is not implemented yet for MPT"
222
+ )
223
+ if (
224
+ attention_mask is not None
225
+ and attention_mask[:, 0].sum() != attention_mask.shape[0]
226
+ and self.training
227
+ ):
228
+ raise NotImplementedError(
229
+ "MPT does not support training with left padding."
230
+ )
231
  if self.prefix_lm and prefix_mask is None:
232
+ raise ValueError(
233
+ "prefix_mask is a required argument when MPT is configured with prefix_lm=True."
234
+ )
235
  if self.training:
236
  if self.attn_uses_sequence_id and sequence_id is None:
237
+ raise ValueError(
238
+ "sequence_id is a required argument when MPT is configured with attn_uses_sequence_id=True "
239
+ + "and the model is in train mode."
240
+ )
241
  elif self.attn_uses_sequence_id is False and sequence_id is not None:
242
+ warnings.warn(
243
+ "MPT received non-None input for `sequence_id` but is configured with attn_uses_sequence_id=False. "
244
+ + "This input will be ignored. If you want the model to use `sequence_id`, set attn_uses_sequence_id to True."
245
+ )
246
  S = input_ids.size(1)
247
+ assert (
248
+ S <= self.config.max_seq_len
249
+ ), f"Cannot forward input with seq_len={S}, this model only supports seq_len<={self.config.max_seq_len}"
250
  tok_emb = self.wte(input_ids)
251
  if self.alibi:
252
  x = tok_emb
 
254
  past_position = 0
255
  if past_key_values is not None:
256
  if len(past_key_values) != self.config.n_layers:
257
+ raise ValueError(
258
+ f"past_key_values must provide a past_key_value for each attention "
259
+ + f"layer in the network (len(past_key_values)={len(past_key_values)!r}; self.config.n_layers={self.config.n_layers!r})."
260
+ )
261
  past_position = past_key_values[0][0].size(1)
262
  if S + past_position > self.config.max_seq_len:
263
+ raise ValueError(
264
+ f"Cannot forward input with past sequence length {past_position} and current sequence length {S + 1}, this model only supports total sequence length <= {self.config.max_seq_len}."
265
+ )
266
+ pos = torch.arange(
267
+ past_position,
268
+ S + past_position,
269
+ dtype=torch.long,
270
+ device=input_ids.device,
271
+ ).unsqueeze(0)
272
  if attention_mask is not None:
273
+ pos = torch.clamp(
274
+ pos
275
+ - torch.cumsum((~attention_mask).to(torch.int32), dim=1)[
276
+ :, past_position:
277
+ ],
278
+ min=0,
279
+ )
280
  pos_emb = self.wpe(pos)
281
  x = tok_emb + pos_emb
282
  if self.embedding_fraction == 1:
283
  x = self.emb_drop(x)
284
  else:
285
+ x_shrunk = x * self.embedding_fraction + x.detach() * (
286
+ 1 - self.embedding_fraction
287
+ )
288
  assert isinstance(self.emb_drop, nn.Module)
289
  x = self.emb_drop(x_shrunk)
290
+ (attn_bias, attention_mask) = self._attn_bias(
291
+ device=x.device,
292
+ dtype=x.dtype,
293
+ attention_mask=attention_mask,
294
+ prefix_mask=prefix_mask,
295
+ sequence_id=sequence_id,
296
+ )
297
  if use_cache and past_key_values is None:
298
  past_key_values = [() for _ in range(self.config.n_layers)]
299
  all_hidden_states = () if output_hidden_states else None
300
+ for b_idx, block in enumerate(self.blocks):
301
  if output_hidden_states:
302
  assert all_hidden_states is not None
303
  all_hidden_states = all_hidden_states + (x,)
304
+ past_key_value = (
305
+ past_key_values[b_idx] if past_key_values is not None else None
306
+ )
307
+ (x, past_key_value) = block(
308
+ x,
309
+ past_key_value=past_key_value,
310
+ attn_bias=attn_bias,
311
+ attention_mask=attention_mask,
312
+ is_causal=self.is_causal,
313
+ )
314
  if past_key_values is not None:
315
  past_key_values[b_idx] = past_key_value
316
  x = self.norm_f(x)
317
+ return BaseModelOutputWithPast(
318
+ last_hidden_state=x,
319
+ past_key_values=past_key_values,
320
+ hidden_states=all_hidden_states,
321
+ )
322
 
323
  def param_init_fn(self, module):
324
+ init_fn_name = self.config.init_config["name"]
325
+ MODEL_INIT_REGISTRY[init_fn_name](
326
+ module=module,
327
+ n_layers=self.config.n_layers,
328
+ d_model=self.config.d_model,
329
+ **self.config.init_config,
330
+ )
331
 
332
  def fsdp_wrap_fn(self, module):
333
  return isinstance(module, MPTBlock)
 
335
  def activation_checkpointing_fn(self, module):
336
  return isinstance(module, MPTBlock)
337
 
 
338
 
339
+ class MPTForCausalLM(MPTPreTrainedModel):
340
  def __init__(self, config: MPTConfig):
341
  super().__init__(config)
342
  if not config.tie_word_embeddings:
343
+ raise ValueError("MPTForCausalLM only supports tied word embeddings")
344
  self.transformer = MPTModel(config)
345
  self.logit_scale = None
346
  if config.logit_scale is not None:
347
  logit_scale = config.logit_scale
348
  if isinstance(logit_scale, str):
349
+ if logit_scale == "inv_sqrt_d_model":
350
  logit_scale = 1 / math.sqrt(config.d_model)
351
  else:
352
+ raise ValueError(
353
+ f"logit_scale={logit_scale!r} is not recognized as an option; use numeric value or 'inv_sqrt_d_model'."
354
+ )
355
  self.logit_scale = logit_scale
356
 
357
  def get_input_embeddings(self):
 
372
  def get_decoder(self):
373
  return self.transformer
374
 
375
+ def forward(
376
+ self,
377
+ input_ids: torch.LongTensor,
378
+ past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None,
379
+ attention_mask: Optional[torch.ByteTensor] = None,
380
+ prefix_mask: Optional[torch.ByteTensor] = None,
381
+ sequence_id: Optional[torch.LongTensor] = None,
382
+ labels: Optional[torch.LongTensor] = None,
383
+ return_dict: Optional[bool] = None,
384
+ output_attentions: Optional[bool] = None,
385
+ output_hidden_states: Optional[bool] = None,
386
+ use_cache: Optional[bool] = None,
387
+ ):
388
+ return_dict = (
389
+ return_dict if return_dict is not None else self.config.return_dict
390
+ )
391
  use_cache = use_cache if use_cache is not None else self.config.use_cache
392
+ outputs = self.transformer(
393
+ input_ids=input_ids,
394
+ past_key_values=past_key_values,
395
+ attention_mask=attention_mask,
396
+ prefix_mask=prefix_mask,
397
+ sequence_id=sequence_id,
398
+ return_dict=return_dict,
399
+ output_attentions=output_attentions,
400
+ output_hidden_states=output_hidden_states,
401
+ use_cache=use_cache,
402
+ )
403
  logits = F.linear(outputs.last_hidden_state, self.transformer.wte.weight)
404
  if self.logit_scale is not None:
405
  if self.logit_scale == 0:
406
+ warnings.warn(
407
+ f"Multiplying logits by self.logit_scale={self.logit_scale!r}. This will produce uniform (uninformative) outputs."
408
+ )
409
  logits *= self.logit_scale
410
  loss = None
411
  if labels is not None:
412
  labels = torch.roll(labels, shifts=-1)
413
  labels[:, -1] = -100
414
+ loss = F.cross_entropy(
415
+ logits.view(-1, logits.size(-1)), labels.to(logits.device).view(-1)
416
+ )
417
+ return CausalLMOutputWithPast(
418
+ loss=loss,
419
+ logits=logits,
420
+ past_key_values=outputs.past_key_values,
421
+ hidden_states=outputs.hidden_states,
422
+ )
423
 
424
  def param_init_fn(self, module):
425
+ init_fn_name = self.config.init_config["name"]
426
+ MODEL_INIT_REGISTRY[init_fn_name](
427
+ module=module,
428
+ n_layers=self.config.n_layers,
429
+ d_model=self.config.d_model,
430
+ **self.config.init_config,
431
+ )
432
 
433
  def fsdp_wrap_fn(self, module):
434
  return isinstance(module, MPTBlock)
 
436
  def activation_checkpointing_fn(self, module):
437
  return isinstance(module, MPTBlock)
438
 
439
+ def prepare_inputs_for_generation(
440
+ self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs
441
+ ):
442
  if inputs_embeds is not None:
443
+ raise NotImplementedError("inputs_embeds is not implemented for MPT yet")
444
+ attention_mask = kwargs["attention_mask"].bool()
445
  if attention_mask[:, -1].sum() != attention_mask.shape[0]:
446
+ raise NotImplementedError(
447
+ "MPT does not support generation with right padding."
448
+ )
449
  if self.transformer.attn_uses_sequence_id and self.training:
450
  sequence_id = torch.zeros_like(input_ids[:1])
451
  else:
 
454
  input_ids = input_ids[:, -1].unsqueeze(-1)
455
  if self.transformer.prefix_lm:
456
  prefix_mask = torch.ones_like(attention_mask)
457
+ if kwargs.get("use_cache") == False:
458
+ raise NotImplementedError(
459
+ "MPT with prefix_lm=True does not support use_cache=False."
460
+ )
461
  else:
462
  prefix_mask = None
463
+ return {
464
+ "input_ids": input_ids,
465
+ "attention_mask": attention_mask,
466
+ "prefix_mask": prefix_mask,
467
+ "sequence_id": sequence_id,
468
+ "past_key_values": past_key_values,
469
+ "use_cache": kwargs.get("use_cache", True),
470
+ }
471
 
472
  @staticmethod
473
  def _reorder_cache(past_key_values, beam_idx):
 
478
  """
479
  reordered_past = []
480
  for layer_past in past_key_values:
481
+ reordered_past += [
482
+ tuple(
483
+ (past_state.index_select(0, beam_idx) for past_state in layer_past)
484
+ )
485
+ ]
486
+ return reordered_past