eluzhnica commited on
Commit
a4a6419
·
1 Parent(s): 0c8ae92

Add gradient checkpointing

Browse files
Files changed (1) hide show
  1. modeling_mpt.py +32 -1
modeling_mpt.py CHANGED
@@ -30,11 +30,18 @@ class MPTPreTrainedModel(PreTrainedModel):
30
  base_model_prefix = 'model'
31
  _no_split_modules = ['MPTBlock']
32
 
 
 
 
 
 
 
33
  class MPTModel(MPTPreTrainedModel):
34
 
35
  def __init__(self, config: MPTConfig):
36
  config._validate_config()
37
  super().__init__(config)
 
38
  self.attn_impl = config.attn_config['attn_impl']
39
  self.prefix_lm = config.attn_config['prefix_lm']
40
  self.attn_uses_sequence_id = config.attn_config['attn_uses_sequence_id']
@@ -161,6 +168,10 @@ class MPTModel(MPTPreTrainedModel):
161
  raise ValueError('sequence_id is a required argument when MPT is configured with attn_uses_sequence_id=True ' + 'and the model is in train mode.')
162
  elif self.attn_uses_sequence_id is False and sequence_id is not None:
163
  warnings.warn('MPT received non-None input for `sequence_id` but is configured with attn_uses_sequence_id=False. ' + 'This input will be ignored. If you want the model to use `sequence_id`, set attn_uses_sequence_id to True.')
 
 
 
 
164
  S = input_ids.size(1)
165
  assert S <= self.config.max_seq_len, f'Cannot forward input with seq_len={S}, this model only supports seq_len<={self.config.max_seq_len}'
166
  tok_emb = self.wte(input_ids)
@@ -197,7 +208,27 @@ class MPTModel(MPTPreTrainedModel):
197
  assert all_hidden_states is not None
198
  all_hidden_states = all_hidden_states + (x,)
199
  past_key_value = past_key_values[b_idx] if past_key_values is not None else None
200
- (x, attn_weights, past_key_value) = block(x, past_key_value=past_key_value, attn_bias=attn_bias, attention_mask=attention_mask, is_causal=self.is_causal)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
201
  if past_key_values is not None:
202
  past_key_values[b_idx] = past_key_value
203
  if output_attentions:
 
30
  base_model_prefix = 'model'
31
  _no_split_modules = ['MPTBlock']
32
 
33
+ supports_gradient_checkpointing = True
34
+
35
+ def _set_gradient_checkpointing(self, module, value=False):
36
+ if isinstance(module, MPTModel):
37
+ module.gradient_checkpointing = value
38
+
39
  class MPTModel(MPTPreTrainedModel):
40
 
41
  def __init__(self, config: MPTConfig):
42
  config._validate_config()
43
  super().__init__(config)
44
+ self.gradient_checkpointing = False
45
  self.attn_impl = config.attn_config['attn_impl']
46
  self.prefix_lm = config.attn_config['prefix_lm']
47
  self.attn_uses_sequence_id = config.attn_config['attn_uses_sequence_id']
 
168
  raise ValueError('sequence_id is a required argument when MPT is configured with attn_uses_sequence_id=True ' + 'and the model is in train mode.')
169
  elif self.attn_uses_sequence_id is False and sequence_id is not None:
170
  warnings.warn('MPT received non-None input for `sequence_id` but is configured with attn_uses_sequence_id=False. ' + 'This input will be ignored. If you want the model to use `sequence_id`, set attn_uses_sequence_id to True.')
171
+ if self.gradient_checkpointing and self.training:
172
+ if use_cache:
173
+ use_cache = False
174
+
175
  S = input_ids.size(1)
176
  assert S <= self.config.max_seq_len, f'Cannot forward input with seq_len={S}, this model only supports seq_len<={self.config.max_seq_len}'
177
  tok_emb = self.wte(input_ids)
 
208
  assert all_hidden_states is not None
209
  all_hidden_states = all_hidden_states + (x,)
210
  past_key_value = past_key_values[b_idx] if past_key_values is not None else None
211
+ if self.gradient_checkpointing and self.training:
212
+
213
+ def create_custom_forward(module):
214
+ def custom_forward(*inputs):
215
+ # None for past_key_value
216
+ return module(*inputs)
217
+
218
+ return custom_forward
219
+
220
+ (x, attn_weights, past_key_value) = torch.utils.checkpoint.checkpoint(
221
+ create_custom_forward(block),
222
+ x,
223
+ past_key_value,
224
+ attn_bias,
225
+ attention_mask,
226
+ self.is_causal,
227
+ )
228
+ else:
229
+ (x, attn_weights, past_key_value) = block(x, past_key_value=past_key_value, attn_bias=attn_bias,
230
+ attention_mask=attention_mask, is_causal=self.is_causal)
231
+
232
  if past_key_values is not None:
233
  past_key_values[b_idx] = past_key_value
234
  if output_attentions: