m10an committed on
Commit d5ca709 · verified · 1 Parent(s): 550be4b

Update bert_layers.py


Allow toggling the Triton attention path on/off via the model config (`config.flash_attn_type`).

Files changed (1)
  1. bert_layers.py +2 -1
bert_layers.py CHANGED
@@ -126,6 +126,7 @@ class BertUnpadSelfAttention(nn.Module):
             warnings.warn(
                 'Unable to import Triton; defaulting MosaicBERT attention implementation to pytorch (this will reduce throughput when using this model).'
             )
+        self.flash_attn_triton_disabled = (flash_attn_qkvpacked_func is None) or (config.flash_attn_type != 'triton')
 
     def forward(self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor,
                 max_seqlen_in_batch: int, indices: torch.Tensor,
@@ -158,7 +159,7 @@ class BertUnpadSelfAttention(nn.Module):
                         'b s (t h d) -> b s t h d',
                         t=3,
                         h=self.num_attention_heads)
-        if self.p_dropout or flash_attn_qkvpacked_func is None:
+        if self.p_dropout or self.flash_attn_triton_disabled:
             # if we have nonzero attention dropout (e.g. during fine-tuning) or no Triton, compute attention in PyTorch
             q = qkv[:, :, 0, :, :].permute(0, 2, 1, 3)  # b h s d
             k = qkv[:, :, 1, :, :].permute(0, 2, 3, 1)  # b h d s
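
For context, a minimal sketch of how the new switch would be used from the loading side. The repo id below is a placeholder, and it is assumed that the accompanying configuration class exposes `flash_attn_type` as a plain string field (the diff only shows where it is read); any value other than 'triton', or a missing Triton install, routes attention through the PyTorch path.

from transformers import AutoConfig, AutoModelForMaskedLM

# Placeholder checkpoint id; substitute the MosaicBERT checkpoint this repo hosts.
repo_id = 'mosaicml/mosaic-bert-base'

config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
# Any value other than 'triton' makes flash_attn_triton_disabled True in
# BertUnpadSelfAttention.__init__, so forward() falls back to PyTorch attention.
config.flash_attn_type = 'pytorch'

model = AutoModelForMaskedLM.from_pretrained(repo_id,
                                             config=config,
                                             trust_remote_code=True)

Conversely, setting config.flash_attn_type = 'triton' (with Triton importable and attention dropout at zero) keeps the Triton flash-attention kernel on the hot path, matching the check in the second hunk above.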