Markus28 committed
Commit d4d5621
1 Parent(s): 75d7a16

feat: added back option not to use flash attention

Files changed (2):
  1. configuration_bert.py +2 -0
  2. modeling_bert.py +4 -2
configuration_bert.py CHANGED
@@ -81,6 +81,7 @@ class JinaBertConfig(PretrainedConfig):
         fused_dropout_add_ln=False,
         fused_bias_fc=False,
         pad_vocab_size_multiple=1,
+        use_flash_attn=True,
         **kwargs,
     ):
         assert 'position_embedding_type' not in kwargs
@@ -106,3 +107,4 @@ class JinaBertConfig(PretrainedConfig):
         self.fused_dropout_add_ln = fused_dropout_add_ln
         self.fused_bias_fc = fused_bias_fc
         self.pad_vocab_size_multiple = pad_vocab_size_multiple
+        self.use_flash_attn = use_flash_attn
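
For reference, a minimal sketch of toggling the new config flag; it assumes the remaining JinaBertConfig arguments keep their defaults:

# Minimal sketch, assuming all other JinaBertConfig arguments stay at their defaults.
from configuration_bert import JinaBertConfig

config = JinaBertConfig(use_flash_attn=False)  # defaults to True when omitted
assert config.use_flash_attn is False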
modeling_bert.py CHANGED
@@ -59,6 +59,7 @@ logger = logging.getLogger(__name__)
 
 
 def create_mixer_cls(config, cross_attn=False, return_residual=False):
+    use_flash_attn = getattr(config, "use_flash_attn", False)
     fused_bias_fc = getattr(config, "fused_bias_fc", False)
     window_size = getattr(config, "window_size", (-1, -1))
     mixer_cls = partial(
@@ -68,7 +69,7 @@ def create_mixer_cls(config, cross_attn=False, return_residual=False):
         dropout=config.attention_probs_dropout_prob,
         causal=False,
         fused_bias_fc=fused_bias_fc,
-        use_flash_attn=True,
+        use_flash_attn=use_flash_attn,
         return_residual=return_residual,
         use_alibi=True,
         window_size=window_size,
@@ -151,6 +152,7 @@ def _init_weights(module, initializer_range=0.02):
 class BertEncoder(nn.Module):
     def __init__(self, config: JinaBertConfig):
         super().__init__()
+        self.use_flash_attn = getattr(config, "use_flash_attn", False)
         self.layers = nn.ModuleList(
             [create_block(config, layer_idx=i) for i in range(config.num_hidden_layers)]
         )
@@ -171,7 +173,7 @@ class BertEncoder(nn.Module):
         This means that we only compute the last layer output for these tokens.
         subset_mask: (batch, seqlen), dtype=torch.bool
         """
-        if key_padding_mask is None:
+        if key_padding_mask is None or not self.use_flash_attn:
             mixer_kwargs = (
                 {"key_padding_mask": key_padding_mask} if key_padding_mask is not None else None
             )
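
A hedged usage sketch of how the flag reaches the model at load time; the repo id below is a placeholder, and passing the override through from_pretrained relies on the standard transformers behavior of forwarding matching kwargs to the config:

from transformers import AutoModel

# "org/jina-bert-checkpoint" is a placeholder repo id, not a real checkpoint.
model = AutoModel.from_pretrained(
    "org/jina-bert-checkpoint",
    trust_remote_code=True,   # required to load the custom modeling_bert.py
    use_flash_attn=False,     # overrides the config default (True) added in this commit
)
# With the flag off, BertEncoder.forward takes the padded path and hands
# key_padding_mask to the mixer via mixer_kwargs instead of the flash-attention branch.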