feat: updated activation checkpointing

#14
Files changed (1)
  1. modeling_bert.py +38 -7
modeling_bert.py CHANGED
@@ -81,7 +81,8 @@ def create_mixer_cls(config, cross_attn=False, return_residual=False):
         return_residual=return_residual,
         use_alibi=True,
         window_size=window_size,
-        qk_norm=use_qk_norm
+        qk_norm=use_qk_norm,
+        checkpointing=False,
     )
     return mixer_cls
 
@@ -174,8 +175,6 @@ class BertEncoder(nn.Module):
     @gradient_checkpointing.setter
     def gradient_checkpointing(self, value):
         self._grad_checkpointing = value
-        for block in self.layers:
-            block.mixer.checkpointing = value
 
     def forward(self, hidden_states, key_padding_mask=None, subset_mask=None):
         """If subset_mask is not None, we only want output for the subset of the sequence.
@@ -187,7 +186,15 @@ class BertEncoder(nn.Module):
                 {"key_padding_mask": key_padding_mask.bool()} if key_padding_mask is not None else None
             )
            for layer in self.layers:
-                hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
+                if self._grad_checkpointing:
+                    hidden_states = torch.utils.checkpoint.checkpoint(
+                        layer,
+                        hidden_states,
+                        use_reentrant=False,
+                        mixer_kwargs=mixer_kwargs
+                    )
+                else:
+                    hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
             if subset_mask is not None:
                 hidden_states = hidden_states[subset_mask]
         else:
@@ -198,11 +205,27 @@ class BertEncoder(nn.Module):
             mixer_kwargs = {"cu_seqlens": cu_seqlens, "max_seqlen": max_seqlen_in_batch}
             if subset_mask is None:
                 for layer in self.layers:
-                    hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
+                    if self._grad_checkpointing:
+                        hidden_states = torch.utils.checkpoint.checkpoint(
+                            layer,
+                            hidden_states,
+                            use_reentrant=False,
+                            mixer_kwargs=mixer_kwargs
+                        )
+                    else:
+                        hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
                 hidden_states = pad_input(hidden_states, indices, batch, seqlen)
             else:
                 for layer in self.layers[:-1]:
-                    hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
+                    if self._grad_checkpointing:
+                        hidden_states = torch.utils.checkpoint.checkpoint(
+                            layer,
+                            hidden_states,
+                            use_reentrant=False,
+                            mixer_kwargs=mixer_kwargs
+                        )
+                    else:
+                        hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
                 if key_padding_mask is not None:
                     subset_idx = torch.nonzero(
                         subset_mask[key_padding_mask], as_tuple=False
@@ -228,7 +251,15 @@ class BertEncoder(nn.Module):
                     "cu_seqlens_k": cu_seqlens,
                     "max_seqlen_k": max_seqlen_in_batch,
                 }
-                hidden_states = self.layers[-1](hidden_states_subset, mixer_kwargs=mixer_kwargs)
+                if self._grad_checkpointing:
+                    hidden_states = torch.utils.checkpoint.checkpoint(
+                        self.layers[-1],
+                        hidden_states_subset,
+                        use_reentrant=False,
+                        mixer_kwargs=mixer_kwargs
+                    )
+                else:
+                    hidden_states = self.layers[-1](hidden_states_subset, mixer_kwargs=mixer_kwargs)
         return hidden_states
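For reference, the per-layer wrapping in this diff follows the standard non-reentrant `torch.utils.checkpoint.checkpoint` pattern. Below is a minimal, self-contained sketch of that pattern; `ToyBlock`, its `mixer_kwargs` parameter, and the tensor shapes are hypothetical stand-ins for this repo's transformer block, and only `torch.utils.checkpoint.checkpoint` itself is the real API:

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class ToyBlock(nn.Module):
    """Hypothetical stand-in for a transformer block that accepts extra kwargs."""

    def __init__(self, dim=64):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, hidden_states, mixer_kwargs=None):
        # mixer_kwargs mimics the keyword extras the real block receives
        # (e.g. key_padding_mask or cu_seqlens); unused in this toy example.
        return torch.relu(self.proj(hidden_states))


layers = nn.ModuleList([ToyBlock() for _ in range(4)])
grad_checkpointing = True  # plays the role of BertEncoder._grad_checkpointing

hidden_states = torch.randn(2, 16, 64, requires_grad=True)
mixer_kwargs = {"key_padding_mask": None}

for layer in layers:
    if grad_checkpointing:
        # Activations inside `layer` are not stored in the forward pass and are
        # recomputed during backward, trading extra compute for less memory.
        hidden_states = checkpoint(
            layer, hidden_states, use_reentrant=False, mixer_kwargs=mixer_kwargs
        )
    else:
        hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)

hidden_states.sum().backward()  # triggers recomputation of each checkpointed block
```

`use_reentrant=False` matters here: the legacy reentrant implementation does not accept extra keyword arguments, so it could not forward `mixer_kwargs` to the block the way the encoder loop does.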