cekal commited on
Commit
a5eab52
1 Parent(s): f2f3202

Update modeling_mpt.py

Browse files
Files changed (1) hide show
  1. modeling_mpt.py +64 -8
modeling_mpt.py CHANGED
@@ -23,12 +23,19 @@ Tokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
23
  class MPTPreTrainedModel(PreTrainedModel):
24
  config_class = MPTConfig
25
  base_model_prefix = 'model'
 
 
 
 
 
 
26
 
27
  class MPTModel(MPTPreTrainedModel):
28
 
29
  def __init__(self, config: MPTConfig):
30
  config._validate_config()
31
  super().__init__(config)
 
32
  self.attn_impl = config.attn_config['attn_impl']
33
  self.prefix_lm = config.attn_config['prefix_lm']
34
  self.attn_uses_sequence_id = config.attn_config['attn_uses_sequence_id']
@@ -127,19 +134,48 @@ class MPTModel(MPTPreTrainedModel):
127
  attn_bias = attn_bias.masked_fill(cannot_attend, min_val)
128
  return attn_bias
129
 
130
- def forward(self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple[torch.FloatTensor]]]=None, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None, return_dict: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, use_cache: Optional[bool]=None):
131
  return_dict = return_dict if return_dict is not None else self.config.return_dict
132
  use_cache = use_cache if use_cache is not None else self.config.use_cache
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
133
  if attention_mask is not None:
134
  attention_mask = attention_mask.bool()
 
 
 
 
 
 
 
 
 
 
135
  if prefix_mask is not None:
136
  prefix_mask = prefix_mask.bool()
137
  if not return_dict:
138
  raise NotImplementedError('return_dict False is not implemented yet for MPT')
139
  if output_attentions:
140
  raise NotImplementedError('output_attentions is not implemented yet for MPT')
141
- if attention_mask is not None and attention_mask[:, 0].sum() != attention_mask.shape[0] and self.training:
142
- raise NotImplementedError('MPT does not support training with left padding.')
143
  if self.prefix_lm and prefix_mask is None:
144
  raise ValueError('prefix_mask is a required argument when MPT is configured with prefix_lm=True.')
145
  if self.training:
@@ -147,9 +183,8 @@ class MPTModel(MPTPreTrainedModel):
147
  raise ValueError('sequence_id is a required argument when MPT is configured with attn_uses_sequence_id=True ' + 'and the model is in train mode.')
148
  elif self.attn_uses_sequence_id is False and sequence_id is not None:
149
  warnings.warn('MPT received non-None input for `sequence_id` but is configured with attn_uses_sequence_id=False. ' + 'This input will be ignored. If you want the model to use `sequence_id`, set attn_uses_sequence_id to True.')
150
- S = input_ids.size(1)
151
  assert S <= self.config.max_seq_len, f'Cannot forward input with seq_len={S}, this model only supports seq_len<={self.config.max_seq_len}'
152
- tok_emb = self.wte(input_ids)
153
  if self.alibi:
154
  x = tok_emb
155
  else:
@@ -161,7 +196,7 @@ class MPTModel(MPTPreTrainedModel):
161
  if S + past_position > self.config.max_seq_len:
162
  raise ValueError(f'Cannot forward input with past sequence length {past_position} and current sequence length {S + 1}, this model only supports total sequence length <= {self.config.max_seq_len}.')
163
  pos = torch.arange(past_position, S + past_position, dtype=torch.long, device=input_ids.device).unsqueeze(0)
164
- if attention_mask is not None:
165
  pos = torch.clamp(pos - torch.cumsum((~attention_mask).to(torch.int32), dim=1)[:, past_position:], min=0)
166
  pos_emb = self.wpe(pos)
167
  x = tok_emb + pos_emb
@@ -174,13 +209,34 @@ class MPTModel(MPTPreTrainedModel):
174
  (attn_bias, attention_mask) = self._attn_bias(device=x.device, dtype=x.dtype, attention_mask=attention_mask, prefix_mask=prefix_mask, sequence_id=sequence_id)
175
  if use_cache and past_key_values is None:
176
  past_key_values = [() for _ in range(self.config.n_layers)]
 
177
  all_hidden_states = () if output_hidden_states else None
178
  for (b_idx, block) in enumerate(self.blocks):
179
  if output_hidden_states:
180
  assert all_hidden_states is not None
181
  all_hidden_states = all_hidden_states + (x,)
182
  past_key_value = past_key_values[b_idx] if past_key_values is not None else None
183
- (x, past_key_value) = block(x, past_key_value=past_key_value, attn_bias=attn_bias, attention_mask=attention_mask, is_causal=self.is_causal)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
184
  if past_key_values is not None:
185
  past_key_values[b_idx] = past_key_value
186
  x = self.norm_f(x)
@@ -234,7 +290,7 @@ class MPTForCausalLM(MPTPreTrainedModel):
234
  def forward(self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple[torch.FloatTensor]]]=None, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None, labels: Optional[torch.LongTensor]=None, return_dict: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, use_cache: Optional[bool]=None, inputs_embeds: Optional[torch.FloatTensor] = None):
235
  return_dict = return_dict if return_dict is not None else self.config.return_dict
236
  use_cache = use_cache if use_cache is not None else self.config.use_cache
237
- outputs = self.transformer(input_ids=input_ids, past_key_values=past_key_values, attention_mask=attention_mask, prefix_mask=prefix_mask, sequence_id=sequence_id, return_dict=return_dict, output_attentions=output_attentions, output_hidden_states=output_hidden_states, use_cache=use_cache)
238
  logits = F.linear(outputs.last_hidden_state, self.transformer.wte.weight)
239
  if self.logit_scale is not None:
240
  if self.logit_scale == 0:
 
23
  class MPTPreTrainedModel(PreTrainedModel):
24
  config_class = MPTConfig
25
  base_model_prefix = 'model'
26
+ _no_split_modules = ["MPTBlock"]
27
+ supports_gradient_checkpointing = True
28
+
29
+ def _set_gradient_checkpointing(self, module, value=False):
30
+ if isinstance(module, MPTModel):
31
+ module.gradient_checkpointing = value
32
 
33
  class MPTModel(MPTPreTrainedModel):
34
 
35
  def __init__(self, config: MPTConfig):
36
  config._validate_config()
37
  super().__init__(config)
38
+ self.gradient_checkpointing = False
39
  self.attn_impl = config.attn_config['attn_impl']
40
  self.prefix_lm = config.attn_config['prefix_lm']
41
  self.attn_uses_sequence_id = config.attn_config['attn_uses_sequence_id']
 
134
  attn_bias = attn_bias.masked_fill(cannot_attend, min_val)
135
  return attn_bias
136
 
137
+ def forward(self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple[torch.FloatTensor]]]=None, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None, return_dict: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, use_cache: Optional[bool]=None, inputs_embeds: Optional[torch.FloatTensor] = None):
138
  return_dict = return_dict if return_dict is not None else self.config.return_dict
139
  use_cache = use_cache if use_cache is not None else self.config.use_cache
140
+ if self.gradient_checkpointing and self.training:
141
+ if use_cache:
142
+ use_cache = False
143
+ if input_ids is not None and inputs_embeds is not None:
144
+ raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
145
+ elif input_ids is not None:
146
+ batch_size, seq_length = input_ids.shape
147
+ elif inputs_embeds is not None:
148
+ batch_size, seq_length, _ = inputs_embeds.shape
149
+ else:
150
+ raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
151
+
152
+ seq_length_with_past = seq_length
153
+ past_key_values_length = 0
154
+
155
+ if past_key_values is not None:
156
+ past_key_values_length = past_key_values[0][0].shape[2]
157
+ seq_length_with_past = seq_length_with_past + past_key_values_length
158
+
159
  if attention_mask is not None:
160
  attention_mask = attention_mask.bool()
161
+ else:
162
+ attention_mask = torch.ones(
163
+ (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
164
+ )
165
+
166
+ if inputs_embeds is None:
167
+ tok_emb = self.wte(input_ids)
168
+ else:
169
+ tok_emb = inputs_embeds
170
+
171
  if prefix_mask is not None:
172
  prefix_mask = prefix_mask.bool()
173
  if not return_dict:
174
  raise NotImplementedError('return_dict False is not implemented yet for MPT')
175
  if output_attentions:
176
  raise NotImplementedError('output_attentions is not implemented yet for MPT')
177
+ #if attention_mask is not None and attention_mask[:, 0].sum() != attention_mask.shape[0] and self.training:
178
+ # raise NotImplementedError('MPT does not support training with left padding.')
179
  if self.prefix_lm and prefix_mask is None:
180
  raise ValueError('prefix_mask is a required argument when MPT is configured with prefix_lm=True.')
181
  if self.training:
 
183
  raise ValueError('sequence_id is a required argument when MPT is configured with attn_uses_sequence_id=True ' + 'and the model is in train mode.')
184
  elif self.attn_uses_sequence_id is False and sequence_id is not None:
185
  warnings.warn('MPT received non-None input for `sequence_id` but is configured with attn_uses_sequence_id=False. ' + 'This input will be ignored. If you want the model to use `sequence_id`, set attn_uses_sequence_id to True.')
186
+ S = seq_length
187
  assert S <= self.config.max_seq_len, f'Cannot forward input with seq_len={S}, this model only supports seq_len<={self.config.max_seq_len}'
 
188
  if self.alibi:
189
  x = tok_emb
190
  else:
 
196
  if S + past_position > self.config.max_seq_len:
197
  raise ValueError(f'Cannot forward input with past sequence length {past_position} and current sequence length {S + 1}, this model only supports total sequence length <= {self.config.max_seq_len}.')
198
  pos = torch.arange(past_position, S + past_position, dtype=torch.long, device=input_ids.device).unsqueeze(0)
199
+ if attention_mask is not None and not self.training:
200
  pos = torch.clamp(pos - torch.cumsum((~attention_mask).to(torch.int32), dim=1)[:, past_position:], min=0)
201
  pos_emb = self.wpe(pos)
202
  x = tok_emb + pos_emb
 
209
  (attn_bias, attention_mask) = self._attn_bias(device=x.device, dtype=x.dtype, attention_mask=attention_mask, prefix_mask=prefix_mask, sequence_id=sequence_id)
210
  if use_cache and past_key_values is None:
211
  past_key_values = [() for _ in range(self.config.n_layers)]
212
+
213
  all_hidden_states = () if output_hidden_states else None
214
  for (b_idx, block) in enumerate(self.blocks):
215
  if output_hidden_states:
216
  assert all_hidden_states is not None
217
  all_hidden_states = all_hidden_states + (x,)
218
  past_key_value = past_key_values[b_idx] if past_key_values is not None else None
219
+
220
+ if self.gradient_checkpointing and self.training:
221
+
222
+ def create_custom_forward(module):
223
+ def custom_forward(*inputs):
224
+ # None for past_key_value
225
+ return module(*inputs)
226
+
227
+ return custom_forward
228
+
229
+ (x, past_key_value) = torch.utils.checkpoint.checkpoint(
230
+ create_custom_forward(block),
231
+ x,
232
+ past_key_value,
233
+ attn_bias,
234
+ attention_mask,
235
+ self.is_causal,
236
+ )
237
+ else:
238
+ (x, past_key_value) = block(x, past_key_value=past_key_value, attn_bias=attn_bias, attention_mask=attention_mask, is_causal=self.is_causal)
239
+
240
  if past_key_values is not None:
241
  past_key_values[b_idx] = past_key_value
242
  x = self.norm_f(x)
 
290
  def forward(self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple[torch.FloatTensor]]]=None, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None, labels: Optional[torch.LongTensor]=None, return_dict: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, use_cache: Optional[bool]=None, inputs_embeds: Optional[torch.FloatTensor] = None):
291
  return_dict = return_dict if return_dict is not None else self.config.return_dict
292
  use_cache = use_cache if use_cache is not None else self.config.use_cache
293
+ outputs = self.transformer(input_ids=input_ids, past_key_values=past_key_values, attention_mask=attention_mask, prefix_mask=prefix_mask, sequence_id=sequence_id, return_dict=return_dict, output_attentions=output_attentions, output_hidden_states=output_hidden_states, use_cache=use_cache, inputs_embeds=inputs_embeds)
294
  logits = F.linear(outputs.last_hidden_state, self.transformer.wte.weight)
295
  if self.logit_scale is not None:
296
  if self.logit_scale == 0: