Training model breaks with flash_attn_2. Error "NameError: name 'index_first_axis' is not defined"

The commit appears to break the code while training.

Error Stack

=> 588 key_layer = index_first_axis(
589 key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
590 )

NameError: name 'index_first_axis' is not defined


I see that the import

from transformers.utils import (

is importing is_flash_attn_2_available, but I could not find that function in the transformers library on GitHub.

Because the condition below evaluates to False, index_first_axis never gets imported, and the later call to it raises the NameError.

if is_flash_attn_2_available():
    from flash_attn import flash_attn_func, flash_attn_varlen_func
    from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
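To illustrate why a skipped conditional import only surfaces later as a NameError, here is a minimal stdlib-only sketch. The function feature_available is a hypothetical stand-in for is_flash_attn_2_available() returning False, and math.sqrt is only a placeholder for the real flash_attn helper; neither is taken from the model code.

```python
def feature_available() -> bool:
    # Hypothetical stand-in for is_flash_attn_2_available() returning False.
    return False


if feature_available():
    # Placeholder import; the real code imports index_first_axis from flash_attn.
    from math import sqrt as index_first_axis


def forward() -> float:
    # The name is looked up only when this line runs, so the skipped
    # import is not noticed until the function is actually called.
    return index_first_axis(4)


try:
    forward()
    message = "no error"
except NameError as exc:
    message = str(exc)

print(message)  # name 'index_first_axis' is not defined
```

The module loads without complaint; the failure appears only once the flash attention code path is executed, which matches the traceback above.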

same error here

Same here. The code worked two days ago, but I did not have enough resources then. Now it fails with the same error.

NameError: name 'index_first_axis' is not defined

Any resolution to this? It is breaking for me.

I'm not using flash attention. That is the only resolution from my end lol

@pavankumarbalijepalli - any recommended alternatives?

Do not use flash attention for now. Try traditional fine-tuning with LoRA.

I realized that I had enabled flash attention while loading the model. You might have done the same thing. Please comment out or remove the flash attention argument from your code and restart the session:
model = AutoModelForCausalLM.from_pretrained(
    model_name, device_map='auto',
    # attn_implementation="flash_attention_2",  # disabled to avoid the NameError
)
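As a rough sketch of the same workaround without hand-editing the call, the loading arguments can be built so that flash attention is requested only when the flash_attn package is installed. The helper names flash_attn_installed and build_load_kwargs are hypothetical. Note the assumption: an installed package does not guarantee that is_flash_attn_2_available() returns True, since that check may also depend on the GPU and package versions, so treat this as a first filter only.

```python
import importlib.util


def flash_attn_installed() -> bool:
    # find_spec on a top-level name looks the package up without importing it.
    return importlib.util.find_spec("flash_attn") is not None


def build_load_kwargs() -> dict:
    # Hypothetical helper: start from the arguments used in this thread
    # and add flash attention only when the package is present.
    kwargs = {"device_map": "auto"}
    if flash_attn_installed():
        kwargs["attn_implementation"] = "flash_attention_2"
    return kwargs


load_kwargs = build_load_kwargs()
print(load_kwargs)

# Usage (needs transformers and your own model_name):
# model = AutoModelForCausalLM.from_pretrained(model_name, **load_kwargs)
```

When the package is missing, the attn_implementation key is simply left out, which has the same effect as commenting the argument out.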

