Reapply cekal/mpt-7b-peft-compatible

#4
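
What this change does: it reapplies the PEFT-compatibility patch. MPTPreTrainedModel gains supports_gradient_checkpointing and _set_gradient_checkpointing, MPTModel.forward accepts inputs_embeds, builds a default attention_mask when none is given, wraps each block in torch.utils.checkpoint when gradient checkpointing is active during training, and the hard error on left-padded training inputs is commented out. Below is a rough usage sketch, not part of the diff: the checkpoint name, the "Wqkv" target module, and the LoRA hyperparameters are illustrative assumptions, and gradient_checkpointing_enable() relies on the transformers hook (_set_gradient_checkpointing) that this patch adds.

# Usage sketch (assumptions noted above): load the patched model with
# trust_remote_code so this modeling_mpt.py is used, turn on gradient
# checkpointing, and attach LoRA adapters with PEFT.
import torch
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

model = AutoModelForCausalLM.from_pretrained(
    "mosaicml/mpt-7b",          # assumed base checkpoint
    trust_remote_code=True,      # required so the custom MPT code is loaded
    torch_dtype=torch.bfloat16,
)

# Works because the patch adds supports_gradient_checkpointing and
# _set_gradient_checkpointing to MPTPreTrainedModel.
model.gradient_checkpointing_enable()

lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    target_modules=["Wqkv"],     # assumed name of MPT's fused QKV projection
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()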
Files changed (1)
  1. modeling_mpt.py +71 -13
modeling_mpt.py CHANGED
@@ -33,13 +33,19 @@ log = logging.getLogger(__name__)
 class MPTPreTrainedModel(PreTrainedModel):
     config_class = MPTConfig
     base_model_prefix = 'model'
-    _no_split_modules = ['MPTBlock']
+    _no_split_modules = ["MPTBlock"]
+    supports_gradient_checkpointing = True
+
+    def _set_gradient_checkpointing(self, module, value=False):
+        if isinstance(module, MPTModel):
+            module.gradient_checkpointing = value
 
 class MPTModel(MPTPreTrainedModel):
 
     def __init__(self, config: MPTConfig):
         config._validate_config()
         super().__init__(config)
+        self.gradient_checkpointing = False
         self.attn_impl = config.attn_config['attn_impl']
         self.prefix_lm = config.attn_config['prefix_lm']
         self.attn_uses_sequence_id = config.attn_config['attn_uses_sequence_id']
@@ -146,8 +152,37 @@ class MPTModel(MPTPreTrainedModel):
     def forward(self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple[torch.FloatTensor]]]=None, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None, return_dict: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, use_cache: Optional[bool]=None, inputs_embeds: Optional[torch.Tensor]=None) -> BaseModelOutputWithPast:
         return_dict = return_dict if return_dict is not None else self.config.return_dict
         use_cache = use_cache if use_cache is not None else self.config.use_cache
+        if self.gradient_checkpointing and self.training:
+            if use_cache:
+                use_cache = False
+        if input_ids is not None and inputs_embeds is not None:
+            raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
+        elif input_ids is not None:
+            batch_size, seq_length = input_ids.shape
+        elif inputs_embeds is not None:
+            batch_size, seq_length, _ = inputs_embeds.shape
+        else:
+            raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
+
+        seq_length_with_past = seq_length
+        past_key_values_length = 0
+
+        if past_key_values is not None:
+            past_key_values_length = past_key_values[0][0].shape[2]
+            seq_length_with_past = seq_length_with_past + past_key_values_length
+
         if attention_mask is not None:
             attention_mask = attention_mask.bool()
+        else:
+            attention_mask = torch.ones(
+                (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
+            )
+
+        if inputs_embeds is None:
+            tok_emb = self.wte(input_ids)
+        else:
+            tok_emb = inputs_embeds
+
         if prefix_mask is not None:
             prefix_mask = prefix_mask.bool()
         if not return_dict:
@@ -155,8 +190,8 @@ class MPTModel(MPTPreTrainedModel):
         if output_attentions:
             if self.attn_impl != 'torch':
                 raise NotImplementedError('output_attentions is not implemented for MPT when using attn_impl `flash` or `triton`.')
-        if self.training and attention_mask is not None and (attention_mask[:, 0].sum() != attention_mask.shape[0]):
-            raise NotImplementedError('MPT does not support training with left padding.')
+        #if self.training and attention_mask is not None and (attention_mask[:, 0].sum() != attention_mask.shape[0]):
+        #    raise NotImplementedError('MPT does not support training with left padding.')
         if self.prefix_lm and prefix_mask is None:
             raise ValueError('prefix_mask is a required argument when MPT is configured with prefix_lm=True.')
         if inputs_embeds is not None:
@@ -166,7 +201,7 @@ class MPTModel(MPTPreTrainedModel):
             raise ValueError('sequence_id is a required argument when MPT is configured with attn_uses_sequence_id=True ' + 'and the model is in train mode.')
         elif self.attn_uses_sequence_id is False and sequence_id is not None:
             warnings.warn('MPT received non-None input for `sequence_id` but is configured with attn_uses_sequence_id=False. ' + 'This input will be ignored. If you want the model to use `sequence_id`, set attn_uses_sequence_id to True.')
-        S = input_ids.size(1)
+        S = seq_length
         assert S <= self.config.max_seq_len, f'Cannot forward input with seq_len={S}, this model only supports seq_len<={self.config.max_seq_len}'
         tok_emb = self.wte(input_ids)
         if self.learned_pos_emb:
@@ -180,7 +215,7 @@ class MPTModel(MPTPreTrainedModel):
             if S + past_position > self.config.max_seq_len:
                 raise ValueError(f'Cannot forward input with past sequence length {past_position} and current sequence length ' + f'{S + 1}, this model only supports total sequence length <= {self.config.max_seq_len}.')
             pos = torch.arange(past_position, S + past_position, dtype=torch.long, device=input_ids.device).unsqueeze(0)
-            if attention_mask is not None:
+            if attention_mask is not None and not self.training:
                 pos = torch.clamp(pos - torch.cumsum((~attention_mask).to(torch.int32), dim=1)[:, past_position:], min=0)
             pos_emb = self.wpe(pos)
             x = tok_emb + pos_emb
@@ -196,6 +231,7 @@ class MPTModel(MPTPreTrainedModel):
         presents = () if use_cache else None
         if use_cache and past_key_values is None:
             past_key_values = [() for _ in range(self.config.n_layers)]
+
         all_hidden_states = () if output_hidden_states else None
         all_self_attns = () if output_attentions else None
         for (b_idx, block) in enumerate(self.blocks):
@@ -203,12 +239,34 @@ class MPTModel(MPTPreTrainedModel):
                 assert all_hidden_states is not None
                 all_hidden_states = all_hidden_states + (x,)
             past_key_value = past_key_values[b_idx] if past_key_values is not None else None
-            (x, attn_weights, present) = block(x, past_key_value=past_key_value, attn_bias=attn_bias, attention_mask=attention_mask, is_causal=self.is_causal, output_attentions=bool(output_attentions))
-            if presents is not None:
-                presents += (present,)
-            if output_attentions:
-                assert all_self_attns is not None
-                all_self_attns = all_self_attns + (attn_weights,)
+            if self.gradient_checkpointing and self.training:
+
+                def create_custom_forward(module):
+                    def custom_forward(*inputs):
+                        # None for past_key_value
+                        return module(*inputs)
+
+                    return custom_forward
+
+                (x, past_key_value) = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(block),
+                    x,
+                    past_key_value,
+                    attn_bias,
+                    attention_mask,
+                    self.is_causal,
+                )
+                if past_key_values is not None:
+                    past_key_values[b_idx] = past_key_value
+            else:
+                (x, attn_weights, present) = block(x, past_key_value=past_key_value, attn_bias=attn_bias, attention_mask=attention_mask, is_causal=self.is_causal, output_attentions=bool(output_attentions))
+                if presents is not None:
+                    presents += (present,)
+                if output_attentions:
+                    assert all_self_attns is not None
+                    all_self_attns = all_self_attns + (attn_weights,)
+
+
         x = self.norm_f(x)
         if output_hidden_states:
             assert all_hidden_states is not None
@@ -271,7 +329,7 @@ class MPTForCausalLM(MPTPreTrainedModel):
         use_cache = use_cache if use_cache is not None else self.config.use_cache
         if inputs_embeds is not None:
             raise NotImplementedError('inputs_embeds has to be None (for hf/peft support).')
-        outputs = self.transformer(input_ids=input_ids, past_key_values=past_key_values, attention_mask=attention_mask, prefix_mask=prefix_mask, sequence_id=sequence_id, return_dict=return_dict, output_attentions=output_attentions, output_hidden_states=output_hidden_states, use_cache=use_cache)
+        outputs = self.transformer(input_ids=input_ids, past_key_values=past_key_values, attention_mask=attention_mask, prefix_mask=prefix_mask, sequence_id=sequence_id, return_dict=return_dict, output_attentions=output_attentions, output_hidden_states=output_hidden_states, use_cache=use_cache, inputs_embeds=inputs_embeds)
         logits = self.transformer.wte(outputs.last_hidden_state.to(self.transformer.wte.weight.device), True)
         if self.logit_scale is not None:
             if self.logit_scale == 0:
@@ -324,4 +382,4 @@ class MPTForCausalLM(MPTPreTrainedModel):
         reordered_past = []
         for layer_past in past_key_values:
            reordered_past += [tuple((past_state.index_select(0, beam_idx) for past_state in layer_past))]
-        return reordered_past
+        return reordered_past
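
Note on the checkpointed path: when gradient_checkpointing is on and the model is training, use_cache is forced off and each MPTBlock call goes through torch.utils.checkpoint.checkpoint via the create_custom_forward closure, so block activations are recomputed during the backward pass instead of being stored; attention weights and the presents cache are not collected on that branch. Below is a minimal, self-contained sketch of the same pattern; ToyBlock and the tensor shapes are illustrative stand-ins, not MPTBlock.

# Sketch of activation checkpointing: wrap each block in
# torch.utils.checkpoint so intermediate activations are recomputed in
# backward rather than kept in memory for the whole forward pass.
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class ToyBlock(nn.Module):
    def __init__(self, d_model: int):
        super().__init__()
        self.ff = nn.Sequential(
            nn.Linear(d_model, 4 * d_model), nn.GELU(), nn.Linear(4 * d_model, d_model)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.ff(x)

blocks = nn.ModuleList(ToyBlock(64) for _ in range(4))
x = torch.randn(2, 16, 64, requires_grad=True)
for block in blocks:
    # use_reentrant=False is the variant recommended by recent PyTorch docs;
    # the diff above relies on the older default (reentrant) implementation.
    x = checkpoint(block, x, use_reentrant=False)
x.sum().backward()
print(x.shape)  # torch.Size([2, 16, 64])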