Hyungtae Kim committed
Commit: 8fb2de8
Parent: 0261af7

Apply changes from cekal/mpt-7b-peft-compatible

Files changed (2)
  1. README.md +6 -0
  2. modeling_mpt.py +60 -6
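The net effect of the changes below is that the custom MPT code can be fine-tuned with gradient checkpointing and parameter-efficient methods. A minimal usage sketch follows; the repository id and the LoRA settings are placeholders, not part of this commit, and it assumes `transformers` and `peft` are installed:

# Minimal sketch; "your-org/mpt-30b-peft-compatible" is a placeholder for this repository's id.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import LoraConfig, get_peft_model

model = AutoModelForCausalLM.from_pretrained(
    "your-org/mpt-30b-peft-compatible",   # placeholder repo id
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,               # loads the custom modeling_mpt.py shown below
)
tokenizer = AutoTokenizer.from_pretrained("your-org/mpt-30b-peft-compatible", trust_remote_code=True)

# Made possible by supports_gradient_checkpointing / _set_gradient_checkpointing added below.
model.gradient_checkpointing_enable()

# Illustrative LoRA settings; "Wqkv" targets MPT's fused attention projection.
peft_config = LoraConfig(r=8, lora_alpha=16, lora_dropout=0.05, target_modules=["Wqkv"], task_type="CAUSAL_LM")
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()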
README.md CHANGED
@@ -14,6 +14,12 @@ datasets:
 inference: false
 ---
 
+### Attribution
+
+This model is derived from [MosaicML's MPT-30B model](https://huggingface.co/mosaicml/mpt-30b/tree/main), with changes from
+[cekal/mpt-7b-peft-compatible](https://huggingface.co/cekal/mpt-7b-peft-compatible) applied; each licensed under the
+Apache License, version 2.0.
+
 # MPT-30B
 
 MPT-30B is a decoder-style transformer pretrained from scratch on 1T tokens of English text and code.
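The modeling_mpt.py changes below also relax the left-padding guard: the hard error on padded training batches is commented out, and the padding adjustment of learned position ids is skipped during training. A sketch of a padded training batch that the patched forward accepts (model and tokenizer as loaded in the snippet above; the texts and the pad-token choice are illustrative):

# Sketch only: the GPT-NeoX tokenizer used by MPT ships without a pad token, so one is assigned here.
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"
batch = tokenizer(
    ["short example", "a somewhat longer training example"],
    padding=True,
    return_tensors="pt",
)
model.train()
out = model(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"], labels=batch["input_ids"])
out.loss.backward()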
modeling_mpt.py CHANGED
@@ -28,13 +28,19 @@ Tokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
 class MPTPreTrainedModel(PreTrainedModel):
     config_class = MPTConfig
     base_model_prefix = 'model'
-    _no_split_modules = ['MPTBlock']
+    _no_split_modules = ["MPTBlock"]
+    supports_gradient_checkpointing = True
+
+    def _set_gradient_checkpointing(self, module, value=False):
+        if isinstance(module, MPTModel):
+            module.gradient_checkpointing = value
 
 class MPTModel(MPTPreTrainedModel):
 
     def __init__(self, config: MPTConfig):
         config._validate_config()
         super().__init__(config)
+        self.gradient_checkpointing = False
         self.attn_impl = config.attn_config['attn_impl']
         self.prefix_lm = config.attn_config['prefix_lm']
         self.attn_uses_sequence_id = config.attn_config['attn_uses_sequence_id']
@@ -143,8 +149,37 @@ class MPTModel(MPTPreTrainedModel):
     def forward(self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple[torch.FloatTensor]]]=None, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None, return_dict: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, use_cache: Optional[bool]=None, inputs_embeds: Optional[torch.Tensor]=None):
         return_dict = return_dict if return_dict is not None else self.config.return_dict
         use_cache = use_cache if use_cache is not None else self.config.use_cache
+        if self.gradient_checkpointing and self.training:
+            if use_cache:
+                use_cache = False
+        if input_ids is not None and inputs_embeds is not None:
+            raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
+        elif input_ids is not None:
+            batch_size, seq_length = input_ids.shape
+        elif inputs_embeds is not None:
+            batch_size, seq_length, _ = inputs_embeds.shape
+        else:
+            raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
+
+        seq_length_with_past = seq_length
+        past_key_values_length = 0
+
+        if past_key_values is not None:
+            past_key_values_length = past_key_values[0][0].shape[2]
+            seq_length_with_past = seq_length_with_past + past_key_values_length
+
         if attention_mask is not None:
             attention_mask = attention_mask.bool()
+        else:
+            attention_mask = torch.ones(
+                (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
+            )
+
+        if inputs_embeds is None:
+            tok_emb = self.wte(input_ids)
+        else:
+            tok_emb = inputs_embeds
+
         if prefix_mask is not None:
             prefix_mask = prefix_mask.bool()
         if not return_dict:
@@ -152,8 +187,8 @@ class MPTModel(MPTPreTrainedModel):
         if output_attentions:
             if self.attn_impl != 'torch':
                 raise NotImplementedError('output_attentions is not implemented for MPT when using attn_impl `flash` or `triton`.')
-        if attention_mask is not None and attention_mask[:, 0].sum() != attention_mask.shape[0] and self.training:
-            raise NotImplementedError('MPT does not support training with left padding.')
+        # if attention_mask is not None and attention_mask[:, 0].sum() != attention_mask.shape[0] and self.training:
+        #     raise NotImplementedError('MPT does not support training with left padding.')
         if self.prefix_lm and prefix_mask is None:
             raise ValueError('prefix_mask is a required argument when MPT is configured with prefix_lm=True.')
         if inputs_embeds is not None:
@@ -163,7 +198,7 @@ class MPTModel(MPTPreTrainedModel):
                 raise ValueError('sequence_id is a required argument when MPT is configured with attn_uses_sequence_id=True ' + 'and the model is in train mode.')
             elif self.attn_uses_sequence_id is False and sequence_id is not None:
                 warnings.warn('MPT received non-None input for `sequence_id` but is configured with attn_uses_sequence_id=False. ' + 'This input will be ignored. If you want the model to use `sequence_id`, set attn_uses_sequence_id to True.')
-        S = input_ids.size(1)
+        S = seq_length
         assert S <= self.config.max_seq_len, f'Cannot forward input with seq_len={S}, this model only supports seq_len<={self.config.max_seq_len}'
         tok_emb = self.wte(input_ids)
         if self.alibi:
@@ -179,7 +214,7 @@ class MPTModel(MPTPreTrainedModel):
             if S + past_position > self.config.max_seq_len:
                 raise ValueError(f'Cannot forward input with past sequence length {past_position} and current sequence length {S + 1}, this model only supports total sequence length <= {self.config.max_seq_len}.')
             pos = torch.arange(past_position, S + past_position, dtype=torch.long, device=input_ids.device).unsqueeze(0)
-            if attention_mask is not None:
+            if attention_mask is not None and not self.training:
                 pos = torch.clamp(pos - torch.cumsum((~attention_mask).to(torch.int32), dim=1)[:, past_position:], min=0)
             pos_emb = self.wpe(pos)
             x = tok_emb + pos_emb
@@ -199,7 +234,26 @@ class MPTModel(MPTPreTrainedModel):
                 assert all_hidden_states is not None
                 all_hidden_states = all_hidden_states + (x,)
             past_key_value = past_key_values[b_idx] if past_key_values is not None else None
-            (x, attn_weights, past_key_value) = block(x, past_key_value=past_key_value, attn_bias=attn_bias, attention_mask=attention_mask, is_causal=self.is_causal)
+            if self.gradient_checkpointing and self.training:
+
+                def create_custom_forward(module):
+                    def custom_forward(*inputs):
+                        # None for past_key_value
+                        return module(*inputs)
+
+                    return custom_forward
+
+                (x, attn_weights, past_key_value) = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(block),
+                    x,
+                    past_key_value,
+                    attn_bias,
+                    attention_mask,
+                    self.is_causal,
+                )
+            else:
+                (x, attn_weights, past_key_value) = block(x, past_key_value=past_key_value, attn_bias=attn_bias, attention_mask=attention_mask, is_causal=self.is_causal)
+
             if past_key_values is not None:
                 past_key_values[b_idx] = past_key_value
             if output_attentions:
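For reference, the per-block wrapping in the last hunk follows the standard activation-checkpointing pattern: during training, each block's intermediate activations are discarded after the forward pass and recomputed during backward, trading extra compute for lower peak memory. A standalone sketch of the same mechanism (toy module, not MPT code):

# Toy illustration of torch.utils.checkpoint with the create_custom_forward wrapper used above.
import torch
import torch.nn as nn
import torch.utils.checkpoint

block = nn.Sequential(nn.Linear(64, 256), nn.GELU(), nn.Linear(256, 64))
x = torch.randn(4, 64, requires_grad=True)

def create_custom_forward(module):
    def custom_forward(*inputs):
        return module(*inputs)
    return custom_forward

# Activations inside `block` are recomputed during backward instead of being stored.
y = torch.utils.checkpoint.checkpoint(create_custom_forward(block), x)
y.sum().backward()
print(x.grad.shape)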