Markus28 committed on
Commit ed92835 · 1 Parent(s): 03d8e7c

feat: try to monkey-patch index_first_axis

Files changed (1): modeling_bert.py (+8 −4)
modeling_bert.py CHANGED
@@ -28,12 +28,16 @@ from transformers.models.bert.modeling_bert import (
     BaseModelOutputWithPoolingAndCrossAttentions,
     BertForPreTrainingOutput,
 )
-from .patched_padding_bert import index_first_axis
+from .patched_padding_bert import index_first_axis as index_first_axis_monkey_patch
+import flash_attn.bert_padding
+flash_attn.bert_padding.index_first_axis = index_first_axis_monkey_patch
+"""
 from flash_attn.bert_padding import (
     index_first_axis_residual,
     pad_input,
     unpad_input,
 )
+"""
 from flash_attn.modules.block import Block
 from flash_attn.modules.embedding import BertEmbeddings
 from flash_attn.modules.mha import MHA
@@ -172,14 +176,14 @@ class BertEncoder(nn.Module):
                 hidden_states = hidden_states[subset_mask]
         else:
             batch, seqlen = hidden_states.shape[:2]
-            hidden_states, indices, cu_seqlens, max_seqlen_in_batch = unpad_input(
+            hidden_states, indices, cu_seqlens, max_seqlen_in_batch = flash_attn.bert_padding.unpad_input(
                 hidden_states, key_padding_mask
             )
             mixer_kwargs = {"cu_seqlens": cu_seqlens, "max_seqlen": max_seqlen_in_batch}
             if subset_mask is None:
                 for layer in self.layers:
                     hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
-                hidden_states = pad_input(hidden_states, indices, batch, seqlen)
+                hidden_states = flash_attn.bert_padding.pad_input(hidden_states, indices, batch, seqlen)
             else:
                 for layer in self.layers[:-1]:
                     hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
@@ -197,7 +201,7 @@ class BertEncoder(nn.Module):
                 subset_cu_seqlens = F.pad(
                     torch.cumsum(subset_seqlens, dim=0, dtype=torch.torch.int32), (1, 0)
                 )
-                hidden_states_subset, hidden_states = index_first_axis_residual(
+                hidden_states_subset, hidden_states = flash_attn.bert_padding.index_first_axis_residual(
                     hidden_states, subset_idx
                 )
                 # It's ok to set max_seqlen_q to be much larger
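
Why the patch is applied this way: `from module import name` copies the binding at import time, so re-assigning `flash_attn.bert_padding.index_first_axis` afterwards is only visible to code that resolves the name through the module at call time; patching the module attribute also rewrites the globals that functions defined inside `flash_attn.bert_padding` (such as `unpad_input`) use, which is presumably the point of the commit. Below is a minimal, self-contained sketch of that binding behaviour, using a throwaway stand-in module rather than flash_attn itself:

```python
# Minimal sketch of module-attribute monkey-patching.
# "demo_padding" is a hypothetical stand-in module, not flash_attn.bert_padding.
import sys
import types

demo_padding = types.ModuleType("demo_padding")
demo_padding.index_first_axis = lambda x: f"original({x})"
sys.modules["demo_padding"] = demo_padding

# Early binding: this name keeps pointing at the original function object.
from demo_padding import index_first_axis as bound_at_import

# Late binding: every call looks the attribute up on the module.
import demo_padding as padding

# The monkey-patch, analogous to
# flash_attn.bert_padding.index_first_axis = index_first_axis_monkey_patch
padding.index_first_axis = lambda x: f"patched({x})"

print(bound_at_import("x"))           # original(x) -- misses the patch
print(padding.index_first_axis("x"))  # patched(x)  -- sees the patch
```

That binding rule is presumably also why the diff comments out the old `from flash_attn.bert_padding import (...)` block with a docstring and switches the encoder to the fully qualified `flash_attn.bert_padding.unpad_input` / `pad_input` / `index_first_axis_residual` calls.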
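
For readers unfamiliar with the padding round trip that the second hunk now calls through the module: `unpad_input` packs only the real (non-padded) tokens into a `(total_tokens, hidden)` tensor plus the `indices`, `cu_seqlens`, and `max_seqlen_in_batch` that the FlashAttention kernels consume, and `pad_input` scatters the result back to `(batch, seqlen, hidden)`. The following is a rough plain-PyTorch mimic of the assumed semantics, not the flash_attn implementation:

```python
# Plain-PyTorch sketch of the unpad/pad round trip used in BertEncoder.forward.
# The real flash_attn.bert_padding.unpad_input / pad_input are assumed to behave
# equivalently; the four return values follow the names used in the diff.
import torch
import torch.nn.functional as F

batch, seqlen, hidden = 2, 4, 3
hidden_states = torch.randn(batch, seqlen, hidden)
key_padding_mask = torch.tensor([[1, 1, 1, 0],
                                 [1, 1, 0, 0]], dtype=torch.bool)

# "unpad": keep only real tokens, flattened to (total_tokens, hidden)
seqlens = key_padding_mask.sum(dim=-1, dtype=torch.int32)
indices = torch.nonzero(key_padding_mask.flatten(), as_tuple=False).flatten()
cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))
max_seqlen_in_batch = int(seqlens.max())
packed = hidden_states.reshape(batch * seqlen, hidden)[indices]

# ... the attention layers consume (packed, cu_seqlens, max_seqlen_in_batch) ...

# "pad": scatter the packed tokens back to (batch, seqlen, hidden)
restored = hidden_states.new_zeros(batch * seqlen, hidden)
restored[indices] = packed
restored = restored.reshape(batch, seqlen, hidden)

assert torch.equal(restored[key_padding_mask], hidden_states[key_padding_mask])
```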