feat: cleave off layers from encoder

#11
Files changed (1)
  1. modeling_bert.py +23 -4
modeling_bert.py CHANGED
@@ -166,6 +166,25 @@ class BertEncoder(nn.Module):
             [create_block(config, layer_idx=i) for i in range(config.num_hidden_layers)]
         )
         self._grad_checkpointing = False
+        self._last_layer_idx = len(self.layers) - 1
+
+    @property
+    def last_layer_idx(self):
+        return self._last_layer_idx
+
+    @last_layer_idx.setter
+    def last_layer_idx(self, idx: int):
+        assert 0 <= idx < len(self.layers)
+        self._last_layer_idx = idx
+
+    @property
+    def cleaved_layers(self):
+        return len(self.layers) - self.last_layer_idx - 1
+
+    @cleaved_layers.setter
+    def cleaved_layers(self, n: int):
+        assert 0 <= n < len(self.layers)
+        self.last_layer_idx = len(self.layers) - n - 1
 
     @property
     def gradient_checkpointing(self):
@@ -186,7 +205,7 @@ class BertEncoder(nn.Module):
             mixer_kwargs = (
                 {"key_padding_mask": key_padding_mask.bool()} if key_padding_mask is not None else None
             )
-            for layer in self.layers:
+            for layer in self.layers[:self.last_layer_idx + 1]:
                 hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
             if subset_mask is not None:
                 hidden_states = hidden_states[subset_mask]
@@ -197,11 +216,11 @@ class BertEncoder(nn.Module):
             )
             mixer_kwargs = {"cu_seqlens": cu_seqlens, "max_seqlen": max_seqlen_in_batch}
             if subset_mask is None:
-                for layer in self.layers:
+                for layer in self.layers[:self.last_layer_idx + 1]:
                     hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
                 hidden_states = pad_input(hidden_states, indices, batch, seqlen)
             else:
-                for layer in self.layers[:-1]:
+                for layer in self.layers[:self.last_layer_idx]:
                     hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
                 if key_padding_mask is not None:
                     subset_idx = torch.nonzero(
@@ -228,7 +247,7 @@ class BertEncoder(nn.Module):
                     "cu_seqlens_k": cu_seqlens,
                     "max_seqlen_k": max_seqlen_in_batch,
                 }
-                hidden_states = self.layers[-1](hidden_states_subset, mixer_kwargs=mixer_kwargs)
+                hidden_states = self.layers[self.last_layer_idx](hidden_states_subset, mixer_kwargs=mixer_kwargs)
         return hidden_states
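
For context, a minimal usage sketch of the two knobs this PR adds. The transformers.BertConfig import and the standalone BertEncoder(config) construction are assumptions about how this module is usually wired up, not something the diff itself establishes; only the last_layer_idx and cleaved_layers properties come from this change.

# Sketch only: BertConfig and direct construction of BertEncoder are assumptions;
# the cleaving properties (last_layer_idx, cleaved_layers) are the ones added in this PR.
from transformers import BertConfig

from modeling_bert import BertEncoder

config = BertConfig(num_hidden_layers=12)
encoder = BertEncoder(config)

# Default: nothing is cleaved, forward() runs all 12 blocks.
assert encoder.cleaved_layers == 0
assert encoder.last_layer_idx == 11

# Cleave off the top two blocks: forward() now stops after layer index 9,
# and the subset-mask path uses self.layers[9] as its final block.
encoder.cleaved_layers = 2
assert encoder.last_layer_idx == 9

# Equivalent, phrased as an index instead of a count.
encoder.last_layer_idx = 9
assert encoder.cleaved_layers == 2

Both setters assert 0 <= value < len(self.layers), so at least one layer always remains in the forward pass and out-of-range values fail loudly instead of silently wrapping.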