Lint flash_attn.py
Browse files
src/axolotl/flash_attn.py
CHANGED
@@ -1,9 +1,10 @@
|
|
|
|
|
|
1 |
# copied from https://github.com/lm-sys/FastChat/blob/main/fastchat/train/llama_flash_attn_monkey_patch.py
|
2 |
|
3 |
-
from typing import
|
4 |
|
5 |
import torch
|
6 |
-
from torch import nn
|
7 |
|
8 |
import transformers
|
9 |
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
|
@@ -14,7 +15,7 @@ from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
|
|
14 |
from flash_attn.bert_padding import unpad_input, pad_input
|
15 |
|
16 |
|
17 |
-
def forward(
|
18 |
self,
|
19 |
hidden_states: torch.Tensor,
|
20 |
attention_mask: Optional[torch.Tensor] = None,
|
@@ -82,6 +83,8 @@ def forward(
|
|
82 |
output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
|
83 |
else:
|
84 |
nheads = qkv.shape[-2]
|
|
|
|
|
85 |
x = rearrange(qkv, "b s three h d -> b s (three h d)")
|
86 |
x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask)
|
87 |
x_unpad = rearrange(
|
@@ -104,13 +107,13 @@ def forward(
|
|
104 |
# requires the attention mask to be the same as the key_padding_mask
|
105 |
def _prepare_decoder_attention_mask(
|
106 |
self, attention_mask, input_shape, inputs_embeds, past_key_values_length
|
107 |
-
):
|
108 |
# [bsz, seq_len]
|
109 |
return attention_mask
|
110 |
|
111 |
|
112 |
def replace_llama_attn_with_flash_attn():
|
113 |
-
transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = (
|
114 |
_prepare_decoder_attention_mask
|
115 |
)
|
116 |
transformers.models.llama.modeling_llama.LlamaAttention.forward = forward
|
|
|
1 |
+
"""Flash attention monkey patch for llama model"""
|
2 |
+
|
3 |
# copied from https://github.com/lm-sys/FastChat/blob/main/fastchat/train/llama_flash_attn_monkey_patch.py
|
4 |
|
5 |
+
from typing import Optional, Tuple
|
6 |
|
7 |
import torch
|
|
|
8 |
|
9 |
import transformers
|
10 |
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
|
|
|
15 |
from flash_attn.bert_padding import unpad_input, pad_input
|
16 |
|
17 |
|
18 |
+
def forward( # pylint: disable=too-many-arguments
|
19 |
self,
|
20 |
hidden_states: torch.Tensor,
|
21 |
attention_mask: Optional[torch.Tensor] = None,
|
|
|
83 |
output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
|
84 |
else:
|
85 |
nheads = qkv.shape[-2]
|
86 |
+
|
87 |
+
# pylint: disable=invalid-name
|
88 |
x = rearrange(qkv, "b s three h d -> b s (three h d)")
|
89 |
x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask)
|
90 |
x_unpad = rearrange(
|
|
|
107 |
# requires the attention mask to be the same as the key_padding_mask
|
108 |
def _prepare_decoder_attention_mask(
|
109 |
self, attention_mask, input_shape, inputs_embeds, past_key_values_length
|
110 |
+
): # pylint: disable=unused-argument
|
111 |
# [bsz, seq_len]
|
112 |
return attention_mask
|
113 |
|
114 |
|
115 |
def replace_llama_attn_with_flash_attn():
|
116 |
+
transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = ( # pylint: disable=protected-access
|
117 |
_prepare_decoder_attention_mask
|
118 |
)
|
119 |
transformers.models.llama.modeling_llama.LlamaAttention.forward = forward
|