Markus28 committed on
Commit 2e2b8d0
1 Parent(s): fabeb13

feat: choose flash attention heuristically if not set explicitly

Files changed (1)
  1. modeling_bert.py +2 -2
modeling_bert.py CHANGED
@@ -66,7 +66,7 @@ logger = logging.getLogger(__name__)
 
 
 def create_mixer_cls(config, cross_attn=False, return_residual=False):
-    use_flash_attn = config.use_flash_attn
+    use_flash_attn = config.use_flash_attn if config.use_flash_attn is not None else torch.cuda.is_available()
     use_qk_norm = config.use_qk_norm
     fused_bias_fc = config.fused_bias_fc
     window_size = config.window_size
@@ -161,7 +161,7 @@ def _init_weights(module, initializer_range=0.02):
 class BertEncoder(nn.Module):
     def __init__(self, config: JinaBertConfig):
         super().__init__()
-        self.use_flash_attn = getattr(config, "use_flash_attn", False)
+        self.use_flash_attn = config.use_flash_attn if config.use_flash_attn is not None else torch.cuda.is_available()
         self.layers = nn.ModuleList(
             [create_block(config, layer_idx=i) for i in range(config.num_hidden_layers)]
         )
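With this change, leaving use_flash_attn unset (i.e. None) in JinaBertConfig makes the model fall back to a CUDA-availability check, while an explicit True/False still takes precedence. A minimal sketch of the resolution logic follows, assuming the config field defaults to None (the config defaults are not part of this diff, and resolve_use_flash_attn is a hypothetical helper, not a function in the repository):

import torch

def resolve_use_flash_attn(config_value):
    # Hypothetical helper mirroring the expression added in this commit:
    # an explicit True/False from the config wins; None means "decide
    # heuristically", i.e. use flash attention whenever CUDA is available.
    if config_value is not None:
        return config_value
    return torch.cuda.is_available()

# resolve_use_flash_attn(True)   -> True  (forced on)
# resolve_use_flash_attn(False)  -> False (forced off)
# resolve_use_flash_attn(None)   -> True on a CUDA machine, False otherwise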