Markus28 committed
Commit 3160695
1 Parent(s): ed92835

feat: reverted monkey patch

Files changed (2):
  1. configuration_bert.py +0 -2
  2. modeling_bert.py +5 -17
configuration_bert.py CHANGED

@@ -14,8 +14,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """ BERT model configuration"""
-from collections import OrderedDict
-from typing import Mapping
 
 from transformers import PretrainedConfig
 
modeling_bert.py CHANGED

@@ -28,16 +28,13 @@ from transformers.models.bert.modeling_bert import (
     BaseModelOutputWithPoolingAndCrossAttentions,
     BertForPreTrainingOutput,
 )
-from .patched_padding_bert import index_first_axis as index_first_axis_monkey_patch
-import flash_attn.bert_padding
-flash_attn.bert_padding.index_first_axis = index_first_axis_monkey_patch
-"""
 from flash_attn.bert_padding import (
+    index_first_axis,
     index_first_axis_residual,
     pad_input,
     unpad_input,
 )
-"""
+
 from flash_attn.modules.block import Block
 from flash_attn.modules.embedding import BertEmbeddings
 from flash_attn.modules.mha import MHA
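For context, the removed lines rebound `flash_attn.bert_padding.index_first_axis` at import time to an implementation from the repo's `patched_padding_bert` module (not shown in this commit). A minimal sketch of that monkey-patch pattern, with an illustrative stand-in body since the patched implementation itself is not part of this diff:

import torch
import flash_attn.bert_padding

def index_first_axis_patched(x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
    # Stand-in body: gather rows along the first axis of a (total_tokens, ...)
    # tensor. The real flash_attn helper wraps this in a custom
    # autograd.Function for a more memory-efficient backward pass.
    return x[indices]

# Rebinding the module attribute redirects every caller that resolves the name
# through the module object (flash_attn.bert_padding.index_first_axis), but not
# code that already ran `from flash_attn.bert_padding import index_first_axis`.
flash_attn.bert_padding.index_first_axis = index_first_axis_patched

Reverting the patch, and importing the helpers by name instead (the `+` lines above), removes that global side effect.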
@@ -176,14 +173,14 @@ class BertEncoder(nn.Module):
                 hidden_states = hidden_states[subset_mask]
         else:
             batch, seqlen = hidden_states.shape[:2]
-            hidden_states, indices, cu_seqlens, max_seqlen_in_batch = flash_attn.bert_padding.unpad_input(
+            hidden_states, indices, cu_seqlens, max_seqlen_in_batch = unpad_input(
                 hidden_states, key_padding_mask
             )
             mixer_kwargs = {"cu_seqlens": cu_seqlens, "max_seqlen": max_seqlen_in_batch}
             if subset_mask is None:
                 for layer in self.layers:
                     hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
-                hidden_states = flash_attn.bert_padding.pad_input(hidden_states, indices, batch, seqlen)
+                hidden_states = pad_input(hidden_states, indices, batch, seqlen)
             else:
                 for layer in self.layers[:-1]:
                     hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
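For reference, `unpad_input` and `pad_input` implement the packing roundtrip that feeds FlashAttention's varlen kernels. A small usage sketch with made-up shapes, assuming the flash-attn generation this repo targets, where `unpad_input` returns a 4-tuple (newer releases add a fifth element):

import torch
from flash_attn.bert_padding import pad_input, unpad_input

batch, seqlen, hidden = 2, 4, 8
hidden_states = torch.randn(batch, seqlen, hidden)
key_padding_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])  # 1 = real token

# (batch, seqlen, hidden) -> (total_tokens, hidden): padding rows are dropped;
# the kernels get cumulative sequence lengths instead of a rectangular batch.
packed, indices, cu_seqlens, max_seqlen_in_batch = unpad_input(hidden_states, key_padding_mask)
assert packed.shape == (5, hidden)        # 3 + 2 real tokens
assert cu_seqlens.tolist() == [0, 3, 5]   # prefix sums of per-sequence lengths
assert max_seqlen_in_batch == 3

# ... the encoder layers run on the packed tokens via mixer_kwargs ...

# Inverse operation: scatter packed rows back, writing zeros at padded positions.
restored = pad_input(packed, indices, batch, seqlen)
assert restored.shape == (batch, seqlen, hidden)

This is also why the encoder only calls `pad_input` on the `subset_mask is None` path: the subset path never needs to restore the full (batch, seqlen) layout.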
@@ -201,7 +198,7 @@ class BertEncoder(nn.Module):
                     subset_cu_seqlens = F.pad(
                         torch.cumsum(subset_seqlens, dim=0, dtype=torch.torch.int32), (1, 0)
                     )
-                hidden_states_subset, hidden_states = flash_attn.bert_padding.index_first_axis_residual(
+                hidden_states_subset, hidden_states = index_first_axis_residual(
                     hidden_states, subset_idx
                 )
                 # It's ok to set max_seqlen_q to be much larger
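Unlike plain `index_first_axis`, the `_residual` variant returns both the gathered subset and the full tensor, so the caller can run the last layer's queries on the subset while keys and values still see every token. A behavioral sketch (assumption: the real flash-attn version is a custom autograd.Function that fuses the two gradient paths; only the observable output is mirrored here):

import torch

def index_first_axis_residual_sketch(x: torch.Tensor, indices: torch.Tensor):
    # (total_tokens, ...) -> ((num_indices, ...), (total_tokens, ...))
    return x[indices], x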
@@ -425,15 +422,6 @@ class BertModel(BertPreTrainedModel):
             pooler_output=pooled_output,
         )
 
-    def to(self, *args, **kwargs):
-        print(f'In BERT, calling to({args, kwargs})')
-        result = super().to(*args, **kwargs)
-        if (len(args) > 0 and isinstance(args[0], torch.dtype)) or "dtype" in kwargs:
-            for layer in result.encoder.layers:
-                layer.mixer.inner_cross_attn.alibi_slopes = layer.mixer.inner_cross_attn.alibi_slopes.to(torch.float32)
-                layer.mixer.inner_attn.alibi_slopes = layer.mixer.inner_attn.alibi_slopes.to(torch.float32)
-        return result
-
 
 class BertForPreTraining(BertPreTrainedModel):
     def __init__(self, config: JinaBertConfig):
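The deleted `to()` override pinned the ALiBi slope tensors back to float32 after any dtype move, since FlashAttention's kernels expect `alibi_slopes` in float32 even when the model itself runs in half precision. A minimal, self-contained sketch of that pattern (toy module; only the buffer-recast idea mirrors the deleted code):

import torch
import torch.nn as nn

class ToyAlibiModule(nn.Module):
    def __init__(self, num_heads: int = 12):
        super().__init__()
        self.proj = nn.Linear(8, 8)
        # FlashAttention wants ALiBi slopes in float32, independent of the
        # compute dtype of the surrounding model.
        self.register_buffer("alibi_slopes", torch.rand(num_heads, dtype=torch.float32))

    def to(self, *args, **kwargs):
        module = super().to(*args, **kwargs)
        # nn.Module.to(dtype) casts floating-point buffers too; undo that cast
        # for the slopes only.
        module.alibi_slopes = module.alibi_slopes.to(torch.float32)
        return module

m = ToyAlibiModule().to(torch.float16)
assert m.proj.weight.dtype == torch.float16
assert m.alibi_slopes.dtype == torch.float32

Dropping the override means the slopes' dtype is no longer special-cased in `BertModel`; the sketch only documents what the removed code did.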
 