Update bert_layers.py
Toggle Triton attention on/off via config
bert_layers.py (+2, −1)
```diff
@@ -126,6 +126,7 @@ class BertUnpadSelfAttention(nn.Module):
             warnings.warn(
                 'Unable to import Triton; defaulting MosaicBERT attention implementation to pytorch (this will reduce throughput when using this model).'
             )
+        self.flash_attn_triton_disabled = (flash_attn_qkvpacked_func is None) or (config.flash_attn_type != 'triton')
 
     def forward(self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor,
                 max_seqlen_in_batch: int, indices: torch.Tensor,
@@ -158,7 +159,7 @@ class BertUnpadSelfAttention(nn.Module):
                         'b s (t h d) -> b s t h d',
                         t=3,
                         h=self.num_attention_heads)
-        if self.p_dropout or flash_attn_qkvpacked_func is None:
+        if self.p_dropout or self.flash_attn_triton_disabled:
             # if we have nonzero attention dropout (e.g. during fine-tuning) or no Triton, compute attention in PyTorch
             q = qkv[:, :, 0, :, :].permute(0, 2, 1, 3)  # b h s d
             k = qkv[:, :, 1, :, :].permute(0, 2, 3, 1)  # b h d s
```
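For context, a hedged usage sketch of the new knob: the attribute name `flash_attn_type` comes from the diff above, but the config construction is hypothetical and assumes MosaicBERT uses a `transformers`-style `BertConfig` that gets threaded down to `BertUnpadSelfAttention`.

```python
# Hypothetical usage sketch (not part of this commit): force the PyTorch
# attention path by setting flash_attn_type to anything other than 'triton'.
from transformers import BertConfig

config = BertConfig()              # assumption: MosaicBERT accepts a BertConfig-style object
config.flash_attn_type = 'torch'   # any value != 'triton' sets flash_attn_triton_disabled = True

# With this config, BertUnpadSelfAttention.forward() always takes the PyTorch
# branch, even when the Triton flash-attention kernel imports successfully.
```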
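The two `permute` calls in the PyTorch branch lay out `q` as `(b, h, s, d)` and `k` as `(b, h, d, s)` so the score matrix falls out of a single matmul. Below is a minimal sketch of that scaled-dot-product fallback under those layouts; the function name is hypothetical, and the repo's actual branch also adds an attention bias before the softmax, omitted here.

```python
import math

import torch


def pytorch_attention_sketch(qkv: torch.Tensor, p_dropout: float = 0.0) -> torch.Tensor:
    """Minimal sketch of the PyTorch fallback; qkv has shape (b, s, 3, h, d)."""
    q = qkv[:, :, 0, :, :].permute(0, 2, 1, 3)  # b h s d
    k = qkv[:, :, 1, :, :].permute(0, 2, 3, 1)  # b h d s (pre-transposed for the matmul)
    v = qkv[:, :, 2, :, :].permute(0, 2, 1, 3)  # b h s d
    scores = torch.matmul(q, k) / math.sqrt(q.size(-1))               # b h s s
    probs = torch.nn.functional.dropout(scores.softmax(dim=-1), p=p_dropout)
    return torch.matmul(probs, v)                                     # b h s d
```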