|
""" |
|
Flash attention monkey patch for the Cerebras BTLM model.
|
""" |
|
|
|
import importlib |
|
import logging |
|
from typing import Optional, Tuple |
|
|
|
import torch |
|
from flash_attn.flash_attn_interface import flash_attn_func |
|
from transformers import AutoConfig, AutoModelForCausalLM |
|
|
|
# Module-level logger; uses the shared "axolotl" logger name rather than __name__.
LOG = logging.getLogger("axolotl")
|
|
|
|
|
def replace_btlm_attn_with_flash_attn(model_name: str = "cerebras/btlm-3b-8k-base") -> None:
    """Monkey-patch the remote-code BTLM attention to use FlashAttention.

    BTLM ships its modeling code through ``trust_remote_code``, so the
    ``modeling_btlm`` module only exists as a dynamically generated module once
    transformers has fetched it. This function makes sure that module is
    imported, then replaces ``BTLMAttention._attn`` with ``flashattn_attn``.
    The patch lands on the class, so it affects every BTLMAttention instance.

    Args:
        model_name: Hub id or local path of the BTLM checkpoint whose remote
            modeling code should be patched.
    """
    # Loading the config imports the dynamic ``configuration_btlm`` module; its
    # package path tells us where the sibling ``modeling_btlm`` module lives.
    model_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)

    # The config load does not fetch ``modeling_btlm``. Resolve the causal-LM
    # class through the config's ``auto_map`` so the module is downloaded and
    # imported without instantiating the model or loading its weights.
    auto_map = getattr(model_config, "auto_map", None) or {}
    class_reference = auto_map.get("AutoModelForCausalLM")
    if class_reference is not None:
        from transformers.dynamic_module_utils import (  # pylint: disable=import-outside-toplevel
            get_class_from_dynamic_module,
        )

        get_class_from_dynamic_module(class_reference, model_name)
    else:
        # Previous behaviour: a full model load is the only other way we know
        # to get the remote modeling module imported.
        LOG.warning(
            "No auto_map['AutoModelForCausalLM'] in config for %s; "
            "loading the full model to import modeling_btlm",
            model_name,
        )
        AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)

    module_name = model_config.__class__.__module__.replace(
        ".configuration_btlm", ".modeling_btlm"
    )
    modeling_btlm = importlib.import_module(module_name)
    modeling_btlm.BTLMAttention._attn = (  # pylint: disable=protected-access
        flashattn_attn
    )
|
|
|
|
|
def flashattn_attn(
    self,
    query: torch.Tensor,
    key: Optional[torch.Tensor] = None,
    value: Optional[torch.Tensor] = None,
    attention_mask: Optional[torch.Tensor] = None,  # pylint: disable=unused-argument
    head_mask: Optional[torch.Tensor] = None,
    position_bias: Optional[torch.Tensor] = None,  # pylint: disable=unused-argument
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """FlashAttention replacement for ``BTLMAttention._attn``.

    Installed on the class as a method, so ``self`` is the attention module.
    The function reads ``self.scale_attn_weights``, ``self.attn_scale_power`` and
    ``self.is_cross_attention`` from it.

    Args:
        query, key, value: Head-major tensors, presumably
            (batch, num_heads, seq_len, head_dim); they are permuted to the
            (batch, seq_len, num_heads, head_dim) layout flash_attn_func expects.
            TODO confirm against the caller.
        attention_mask: Ignored. No padding mask is applied by this kernel.
        head_mask: Optional multiplier applied to the attention output in the
            head-major layout; it must broadcast against
            (batch, num_heads, seq_len, head_dim).
        position_bias: Ignored. Any additive attention bias the model supplies
            (for example ALiBi) is NOT applied here.
            NOTE(review): confirm this is acceptable for the target checkpoint.

    Returns:
        ``(attn_output, None)``. ``attn_output`` is in the same head-major layout
        as ``query``. The attention probabilities are never materialized, so the
        second element is always None.
    """
    # When scale_attn_weights is set, divide QK^T by head_dim ** attn_scale_power.
    # Otherwise apply no scaling at all: passing None would make flash_attn fall
    # back to its default of 1/sqrt(head_dim), silently scaling unscaled attention.
    if self.scale_attn_weights:
        softmax_scale = 1.0 / (key.size(-1) ** self.attn_scale_power)
    else:
        softmax_scale = 1.0

    # flash_attn_func takes (batch, seq_len, num_heads, head_dim).
    query = query.permute(0, 2, 1, 3)
    key = key.permute(0, 2, 1, 3)
    value = value.permute(0, 2, 1, 3)

    # NOTE(review): dropout is hard-wired to 0.0, so the module's attention
    # dropout (if any) is not applied on this path.
    attn_output = flash_attn_func(
        query,
        key,
        value,
        dropout_p=0.0,
        softmax_scale=softmax_scale,
        causal=not self.is_cross_attention,
        return_attn_probs=False,
    )

    # Back to the head-major layout of the inputs.
    attn_output = attn_output.permute(0, 2, 1, 3)

    if head_mask is not None:
        # Applied after the permute so a per-head mask broadcasting over
        # (batch, num_heads, ...) lines up with the heads axis rather than the
        # sequence axis. It is out-of-place because flash_attn saves its output
        # for the backward pass, and mutating it would break autograd.
        attn_output = attn_output * head_mask

    return attn_output, None
|
|