feat: selective activation checkpointing

#16
Files changed (2)
  1. configuration_bert.py +21 -1
  2. modeling_bert.py +12 -8
configuration_bert.py CHANGED

@@ -55,6 +55,10 @@ class JinaBertConfig(PretrainedConfig):
         layer_norm_eps (`float`, *optional*, defaults to 1e-12):
             The epsilon used by the layer normalization layers.
         window_size (`tuple`, *optional*, defaults to `(-1, -1)`): If not the default, use local attention
+        activation_checkpoint_lvl (`int`, *optional*, defaults to `100`): How many layers to activation-checkpoint.
+            If larger than 0, the MLP activation checkpointing level is expected to be 0 for the first
+            `activation_checkpoint_lvl` layers. The activation checkpointing will only come into effect
+            after `model.gradient_checkpointing_enable()` is called.
     """
 
     model_type = "bert"
@@ -86,6 +90,7 @@ class JinaBertConfig(PretrainedConfig):
         emb_pooler=None,
         classifier_dropout=None,
         num_loras=5,
+        activation_checkpoint_lvl=100,
         **kwargs,
     ):
         assert 'position_embedding_type' not in kwargs
@@ -95,6 +100,20 @@ class JinaBertConfig(PretrainedConfig):
         if mlp_type == 'fused_mlp' and hidden_act not in ["gelu_new", "gelu_fast", "gelu_pytorch_tanh"]:
             raise ValueError('Fused MLP only supports approximate gelu')
 
+        if mlp_checkpoint_lvl != 0 and mlp_type != 'fused_mlp':
+            raise ValueError('MLP checkpointing only available for `fused_mlp`')
+
+        if activation_checkpoint_lvl > 0 and isinstance(mlp_checkpoint_lvl, int) and mlp_checkpoint_lvl > 0:
+            raise ValueError('Trying to use layer-wise activation checkpointing and MLP-checkpointing '
+                             'in every layer simultaneously. Either only use one of the techniques, '
+                             'or specify layer-wise MLP checkpointing.')
+        elif activation_checkpoint_lvl > 0 and mlp_checkpoint_lvl > 0:
+            for layer_idx, mlp_lvl in enumerate(mlp_checkpoint_lvl):
+                if layer_idx < activation_checkpoint_lvl and mlp_lvl > 0:
+                    raise ValueError(f'Layer {layer_idx} is being checkpointed as a whole and its MLP '
+                                     f'is being checkpointed. Either remove MLP checkpointing for this layer '
+                                     f'or reduce the `activation_checkpoint_lvl` appropriately')
+
         self.vocab_size = vocab_size
         self.hidden_size = hidden_size
         self.num_hidden_layers = num_hidden_layers
@@ -118,4 +137,5 @@ class JinaBertConfig(PretrainedConfig):
         self.use_qk_norm = use_qk_norm
         self.emb_pooler = emb_pooler
         self.classifier_dropout = classifier_dropout
-        self.num_loras = num_loras
+        self.num_loras = num_loras
+        self.activation_checkpoint_lvl = activation_checkpoint_lvl
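For orientation, here is a minimal usage sketch of the new option (not part of this diff). It assumes the model class exported by modeling_bert.py is called `BertModel` and that the standard Hugging Face `gradient_checkpointing_enable()` hook forwards to `BertEncoder.gradient_checkpointing`; `JinaBertConfig` and its arguments come from the diff above.

```python
# Minimal usage sketch, not part of this PR. `BertModel` as the exported model
# class is an assumption; `JinaBertConfig` and its new arguments come from the
# diff above.
from configuration_bert import JinaBertConfig
from modeling_bert import BertModel  # assumed export name

config = JinaBertConfig(
    num_hidden_layers=24,
    mlp_type='fused_mlp',          # required whenever mlp_checkpoint_lvl != 0
    mlp_checkpoint_lvl=0,          # keep MLP checkpointing off for whole-layer checkpointed layers
    activation_checkpoint_lvl=12,  # checkpoint the first 12 encoder layers as whole blocks
)
model = BertModel(config)

# The config only selects which layers are eligible; checkpointing still has to
# be switched on explicitly, as the docstring notes:
model.gradient_checkpointing_enable()
```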
modeling_bert.py CHANGED

@@ -180,13 +180,17 @@ class BertEncoder(nn.Module):
             [create_block(config, layer_idx=i) for i in range(config.num_hidden_layers)]
         )
         self._grad_checkpointing = False
+        self._num_checkpointed_layers = config.activation_checkpoint_lvl
 
     @property
     def gradient_checkpointing(self):
         return self._grad_checkpointing
 
     @gradient_checkpointing.setter
-    def gradient_checkpointing(self, value):
+    def gradient_checkpointing(self, value: bool):
+        if value and self._num_checkpointed_layers <= 0:
+            raise ValueError('Trying to use activation checkpointing, but `activation_checkpoint_lvl`'
+                             'is set to zero.')
         self._grad_checkpointing = value
 
     def forward(self, hidden_states, key_padding_mask=None, subset_mask=None):
@@ -198,8 +202,8 @@ class BertEncoder(nn.Module):
             mixer_kwargs = (
                 {"key_padding_mask": key_padding_mask.bool()} if key_padding_mask is not None else None
             )
-            for layer in self.layers:
-                if self._grad_checkpointing:
+            for idx, layer in enumerate(self.layers):
+                if self._grad_checkpointing and idx < self._num_checkpointed_layers:
                     hidden_states = torch.utils.checkpoint.checkpoint(
                         layer,
                         hidden_states,
@@ -217,8 +221,8 @@ class BertEncoder(nn.Module):
             )
             mixer_kwargs = {"cu_seqlens": cu_seqlens, "max_seqlen": max_seqlen_in_batch}
             if subset_mask is None:
-                for layer in self.layers:
-                    if self._grad_checkpointing:
+                for idx, layer in enumerate(self.layers):
+                    if self._grad_checkpointing and idx < self._num_checkpointed_layers:
                         hidden_states = torch.utils.checkpoint.checkpoint(
                             layer,
                             hidden_states,
@@ -229,8 +233,8 @@ class BertEncoder(nn.Module):
                         hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
                 hidden_states = pad_input(hidden_states, indices, batch, seqlen)
             else:
-                for layer in self.layers[:-1]:
-                    if self._grad_checkpointing:
+                for idx, layer in enumerate(self.layers[:-1]):
+                    if self._grad_checkpointing and idx < self._num_checkpointed_layers:
                         hidden_states = torch.utils.checkpoint.checkpoint(
                             layer,
                             hidden_states,
@@ -264,7 +268,7 @@ class BertEncoder(nn.Module):
                     "cu_seqlens_k": cu_seqlens,
                     "max_seqlen_k": max_seqlen_in_batch,
                 }
-                if self._grad_checkpointing:
+                if self._grad_checkpointing and len(self.layers) <= self._num_checkpointed_layers:
                     torch.utils.checkpoint.checkpoint(
                         self.layers[-1],
                         hidden_states_subset,
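To read the loop change outside the diff context, here is a self-contained sketch of the same selective-checkpointing pattern on a toy stack of layers. `ToyEncoder`, its sizes, and `use_reentrant=False` are illustrative choices, not taken from the repo.

```python
# Self-contained illustration of the selective-checkpointing pattern introduced
# above: only the first `num_checkpointed` blocks are recomputed in the backward
# pass via torch.utils.checkpoint. `ToyEncoder` and its sizes are made up.
import torch
import torch.nn as nn
import torch.utils.checkpoint


class ToyEncoder(nn.Module):
    def __init__(self, num_layers=4, hidden=32, num_checkpointed=2):
        super().__init__()
        self.layers = nn.ModuleList(
            [nn.Sequential(nn.Linear(hidden, hidden), nn.GELU()) for _ in range(num_layers)]
        )
        self.grad_checkpointing = True
        self.num_checkpointed = num_checkpointed

    def forward(self, hidden_states):
        for idx, layer in enumerate(self.layers):
            if self.grad_checkpointing and self.training and idx < self.num_checkpointed:
                # Checkpointed layer: activations are dropped and recomputed during backward.
                hidden_states = torch.utils.checkpoint.checkpoint(
                    layer, hidden_states, use_reentrant=False
                )
            else:
                # Remaining layers run normally and keep their activations.
                hidden_states = layer(hidden_states)
        return hidden_states


encoder = ToyEncoder().train()
x = torch.randn(8, 32)
encoder(x).sum().backward()  # gradients flow through both checkpointed and regular layers
```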