sahil2801 commited on
Commit
1e18faf
1 Parent(s): d95d373

add gradient checkpointing

Browse files
Files changed (1) hide show
  1. modeling_mpt.py +65 -16
modeling_mpt.py CHANGED
@@ -1,7 +1,3 @@
1
- """A simple, flexible implementation of a GPT model.
2
-
3
- Inspired by https://github.com/karpathy/minGPT/blob/master/mingpt/model.py
4
- """
5
  import math
6
  import warnings
7
  from typing import List, Optional, Tuple, Union
@@ -23,13 +19,19 @@ Tokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
23
  class MPTPreTrainedModel(PreTrainedModel):
24
  config_class = MPTConfig
25
  base_model_prefix = 'model'
26
- _no_split_modules=["MPTBlock"]
 
 
 
 
 
27
 
28
  class MPTModel(MPTPreTrainedModel):
29
 
30
  def __init__(self, config: MPTConfig):
31
  config._validate_config()
32
  super().__init__(config)
 
33
  self.attn_impl = config.attn_config['attn_impl']
34
  self.prefix_lm = config.attn_config['prefix_lm']
35
  self.attn_uses_sequence_id = config.attn_config['attn_uses_sequence_id']
@@ -47,7 +49,6 @@ class MPTModel(MPTPreTrainedModel):
47
  self.blocks = nn.ModuleList([MPTBlock(device=config.init_device, **config.to_dict()) for _ in range(config.n_layers)])
48
  self.norm_f = norm_class(config.d_model, device=config.init_device)
49
  if config.init_device != 'meta':
50
- print(f'You are using config.init_device={config.init_device!r}, but you can also use config.init_device="meta" with Composer + FSDP for fast initialization.')
51
  self.apply(self.param_init_fn)
52
  self.is_causal = not self.prefix_lm
53
  self._attn_bias_initialized = False
@@ -129,19 +130,48 @@ class MPTModel(MPTPreTrainedModel):
129
  attn_bias = attn_bias.masked_fill(cannot_attend, min_val)
130
  return attn_bias
131
 
132
- def forward(self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple[torch.FloatTensor]]]=None, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None, return_dict: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, use_cache: Optional[bool]=None):
133
  return_dict = return_dict if return_dict is not None else self.config.return_dict
134
  use_cache = use_cache if use_cache is not None else self.config.use_cache
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
135
  if attention_mask is not None:
136
  attention_mask = attention_mask.bool()
 
 
 
 
 
 
 
 
 
 
137
  if prefix_mask is not None:
138
  prefix_mask = prefix_mask.bool()
139
  if not return_dict:
140
  raise NotImplementedError('return_dict False is not implemented yet for MPT')
141
  if output_attentions:
142
  raise NotImplementedError('output_attentions is not implemented yet for MPT')
143
- if attention_mask is not None and attention_mask[:, 0].sum() != attention_mask.shape[0] and self.training:
144
- raise NotImplementedError('MPT does not support training with left padding.')
145
  if self.prefix_lm and prefix_mask is None:
146
  raise ValueError('prefix_mask is a required argument when MPT is configured with prefix_lm=True.')
147
  if self.training:
@@ -149,9 +179,8 @@ class MPTModel(MPTPreTrainedModel):
149
  raise ValueError('sequence_id is a required argument when MPT is configured with attn_uses_sequence_id=True ' + 'and the model is in train mode.')
150
  elif self.attn_uses_sequence_id is False and sequence_id is not None:
151
  warnings.warn('MPT received non-None input for `sequence_id` but is configured with attn_uses_sequence_id=False. ' + 'This input will be ignored. If you want the model to use `sequence_id`, set attn_uses_sequence_id to True.')
152
- S = input_ids.size(1)
153
  assert S <= self.config.max_seq_len, f'Cannot forward input with seq_len={S}, this model only supports seq_len<={self.config.max_seq_len}'
154
- tok_emb = self.wte(input_ids)
155
  if self.alibi:
156
  x = tok_emb
157
  else:
@@ -163,7 +192,7 @@ class MPTModel(MPTPreTrainedModel):
163
  if S + past_position > self.config.max_seq_len:
164
  raise ValueError(f'Cannot forward input with past sequence length {past_position} and current sequence length {S + 1}, this model only supports total sequence length <= {self.config.max_seq_len}.')
165
  pos = torch.arange(past_position, S + past_position, dtype=torch.long, device=input_ids.device).unsqueeze(0)
166
- if attention_mask is not None:
167
  pos = torch.clamp(pos - torch.cumsum((~attention_mask).to(torch.int32), dim=1)[:, past_position:], min=0)
168
  pos_emb = self.wpe(pos)
169
  x = tok_emb + pos_emb
@@ -176,13 +205,34 @@ class MPTModel(MPTPreTrainedModel):
176
  (attn_bias, attention_mask) = self._attn_bias(device=x.device, dtype=x.dtype, attention_mask=attention_mask, prefix_mask=prefix_mask, sequence_id=sequence_id)
177
  if use_cache and past_key_values is None:
178
  past_key_values = [() for _ in range(self.config.n_layers)]
 
179
  all_hidden_states = () if output_hidden_states else None
180
  for (b_idx, block) in enumerate(self.blocks):
181
  if output_hidden_states:
182
  assert all_hidden_states is not None
183
  all_hidden_states = all_hidden_states + (x,)
184
  past_key_value = past_key_values[b_idx] if past_key_values is not None else None
185
- (x, past_key_value) = block(x, past_key_value=past_key_value, attn_bias=attn_bias, attention_mask=attention_mask, is_causal=self.is_causal)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
186
  if past_key_values is not None:
187
  past_key_values[b_idx] = past_key_value
188
  x = self.norm_f(x)
@@ -233,10 +283,10 @@ class MPTForCausalLM(MPTPreTrainedModel):
233
  def get_decoder(self):
234
  return self.transformer
235
 
236
- def forward(self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple[torch.FloatTensor]]]=None, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None, labels: Optional[torch.LongTensor]=None, return_dict: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, use_cache: Optional[bool]=None):
237
  return_dict = return_dict if return_dict is not None else self.config.return_dict
238
  use_cache = use_cache if use_cache is not None else self.config.use_cache
239
- outputs = self.transformer(input_ids=input_ids, past_key_values=past_key_values, attention_mask=attention_mask, prefix_mask=prefix_mask, sequence_id=sequence_id, return_dict=return_dict, output_attentions=output_attentions, output_hidden_states=output_hidden_states, use_cache=use_cache)
240
  logits = F.linear(outputs.last_hidden_state, self.transformer.wte.weight)
241
  if self.logit_scale is not None:
242
  if self.logit_scale == 0:
@@ -282,7 +332,6 @@ class MPTForCausalLM(MPTPreTrainedModel):
282
  @staticmethod
283
  def _reorder_cache(past_key_values, beam_idx):
284
  """Used by HuggingFace generate when using beam search with kv-caching.
285
-
286
  See https://github.com/huggingface/transformers/blob/3ec7a47664ebe40c40f4b722f6bb1cd30c3821ec/src/transformers/models/gpt2/modeling_gpt2.py#L1122-L1133
287
  for an example in transformers.
288
  """
 
 
 
 
 
1
  import math
2
  import warnings
3
  from typing import List, Optional, Tuple, Union
 
19
  class MPTPreTrainedModel(PreTrainedModel):
20
  config_class = MPTConfig
21
  base_model_prefix = 'model'
22
+ _no_split_modules = ["MPTBlock"]
23
+ supports_gradient_checkpointing = True
24
+
25
+ def _set_gradient_checkpointing(self, module, value=False):
26
+ if isinstance(module, MPTModel):
27
+ module.gradient_checkpointing = value
28
 
29
  class MPTModel(MPTPreTrainedModel):
30
 
31
  def __init__(self, config: MPTConfig):
32
  config._validate_config()
33
  super().__init__(config)
34
+ self.gradient_checkpointing = False
35
  self.attn_impl = config.attn_config['attn_impl']
36
  self.prefix_lm = config.attn_config['prefix_lm']
37
  self.attn_uses_sequence_id = config.attn_config['attn_uses_sequence_id']
 
49
  self.blocks = nn.ModuleList([MPTBlock(device=config.init_device, **config.to_dict()) for _ in range(config.n_layers)])
50
  self.norm_f = norm_class(config.d_model, device=config.init_device)
51
  if config.init_device != 'meta':
 
52
  self.apply(self.param_init_fn)
53
  self.is_causal = not self.prefix_lm
54
  self._attn_bias_initialized = False
 
130
  attn_bias = attn_bias.masked_fill(cannot_attend, min_val)
131
  return attn_bias
132
 
133
+ def forward(self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple[torch.FloatTensor]]]=None, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None, return_dict: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, use_cache: Optional[bool]=None, inputs_embeds: Optional[torch.FloatTensor] = None):
134
  return_dict = return_dict if return_dict is not None else self.config.return_dict
135
  use_cache = use_cache if use_cache is not None else self.config.use_cache
136
+ if self.gradient_checkpointing and self.training:
137
+ if use_cache:
138
+ use_cache = False
139
+ if input_ids is not None and inputs_embeds is not None:
140
+ raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
141
+ elif input_ids is not None:
142
+ batch_size, seq_length = input_ids.shape
143
+ elif inputs_embeds is not None:
144
+ batch_size, seq_length, _ = inputs_embeds.shape
145
+ else:
146
+ raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
147
+
148
+ seq_length_with_past = seq_length
149
+ past_key_values_length = 0
150
+
151
+ if past_key_values is not None:
152
+ past_key_values_length = past_key_values[0][0].shape[2]
153
+ seq_length_with_past = seq_length_with_past + past_key_values_length
154
+
155
  if attention_mask is not None:
156
  attention_mask = attention_mask.bool()
157
+ else:
158
+ attention_mask = torch.ones(
159
+ (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
160
+ )
161
+
162
+ if inputs_embeds is None:
163
+ tok_emb = self.wte(input_ids)
164
+ else:
165
+ tok_emb = inputs_embeds
166
+
167
  if prefix_mask is not None:
168
  prefix_mask = prefix_mask.bool()
169
  if not return_dict:
170
  raise NotImplementedError('return_dict False is not implemented yet for MPT')
171
  if output_attentions:
172
  raise NotImplementedError('output_attentions is not implemented yet for MPT')
173
+ #if attention_mask is not None and attention_mask[:, 0].sum() != attention_mask.shape[0] and self.training:
174
+ # raise NotImplementedError('MPT does not support training with left padding.')
175
  if self.prefix_lm and prefix_mask is None:
176
  raise ValueError('prefix_mask is a required argument when MPT is configured with prefix_lm=True.')
177
  if self.training:
 
179
  raise ValueError('sequence_id is a required argument when MPT is configured with attn_uses_sequence_id=True ' + 'and the model is in train mode.')
180
  elif self.attn_uses_sequence_id is False and sequence_id is not None:
181
  warnings.warn('MPT received non-None input for `sequence_id` but is configured with attn_uses_sequence_id=False. ' + 'This input will be ignored. If you want the model to use `sequence_id`, set attn_uses_sequence_id to True.')
182
+ S = seq_length
183
  assert S <= self.config.max_seq_len, f'Cannot forward input with seq_len={S}, this model only supports seq_len<={self.config.max_seq_len}'
 
184
  if self.alibi:
185
  x = tok_emb
186
  else:
 
192
  if S + past_position > self.config.max_seq_len:
193
  raise ValueError(f'Cannot forward input with past sequence length {past_position} and current sequence length {S + 1}, this model only supports total sequence length <= {self.config.max_seq_len}.')
194
  pos = torch.arange(past_position, S + past_position, dtype=torch.long, device=input_ids.device).unsqueeze(0)
195
+ if attention_mask is not None and not self.training:
196
  pos = torch.clamp(pos - torch.cumsum((~attention_mask).to(torch.int32), dim=1)[:, past_position:], min=0)
197
  pos_emb = self.wpe(pos)
198
  x = tok_emb + pos_emb
 
205
  (attn_bias, attention_mask) = self._attn_bias(device=x.device, dtype=x.dtype, attention_mask=attention_mask, prefix_mask=prefix_mask, sequence_id=sequence_id)
206
  if use_cache and past_key_values is None:
207
  past_key_values = [() for _ in range(self.config.n_layers)]
208
+
209
  all_hidden_states = () if output_hidden_states else None
210
  for (b_idx, block) in enumerate(self.blocks):
211
  if output_hidden_states:
212
  assert all_hidden_states is not None
213
  all_hidden_states = all_hidden_states + (x,)
214
  past_key_value = past_key_values[b_idx] if past_key_values is not None else None
215
+
216
+ if self.gradient_checkpointing and self.training:
217
+
218
+ def create_custom_forward(module):
219
+ def custom_forward(*inputs):
220
+ # None for past_key_value
221
+ return module(*inputs)
222
+
223
+ return custom_forward
224
+
225
+ (x, past_key_value) = torch.utils.checkpoint.checkpoint(
226
+ create_custom_forward(block),
227
+ x,
228
+ past_key_value,
229
+ attn_bias,
230
+ attention_mask,
231
+ self.is_causal,
232
+ )
233
+ else:
234
+ (x, past_key_value) = block(x, past_key_value=past_key_value, attn_bias=attn_bias, attention_mask=attention_mask, is_causal=self.is_causal)
235
+
236
  if past_key_values is not None:
237
  past_key_values[b_idx] = past_key_value
238
  x = self.norm_f(x)
 
283
  def get_decoder(self):
284
  return self.transformer
285
 
286
+ def forward(self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple[torch.FloatTensor]]]=None, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None, labels: Optional[torch.LongTensor]=None, return_dict: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, use_cache: Optional[bool]=None, inputs_embeds: Optional[torch.FloatTensor] = None):
287
  return_dict = return_dict if return_dict is not None else self.config.return_dict
288
  use_cache = use_cache if use_cache is not None else self.config.use_cache
289
+ outputs = self.transformer(input_ids=input_ids, past_key_values=past_key_values, attention_mask=attention_mask, prefix_mask=prefix_mask, sequence_id=sequence_id, return_dict=return_dict, output_attentions=output_attentions, output_hidden_states=output_hidden_states, use_cache=use_cache, inputs_embeds=inputs_embeds)
290
  logits = F.linear(outputs.last_hidden_state, self.transformer.wte.weight)
291
  if self.logit_scale is not None:
292
  if self.logit_scale == 0:
 
332
  @staticmethod
333
  def _reorder_cache(past_key_values, beam_idx):
334
  """Used by HuggingFace generate when using beam search with kv-caching.
 
335
  See https://github.com/huggingface/transformers/blob/3ec7a47664ebe40c40f4b722f6bb1cd30c3821ec/src/transformers/models/gpt2/modeling_gpt2.py#L1122-L1133
336
  for an example in transformers.
337
  """