Training breaks with flash_attn_2: "NameError: name 'index_first_axis' is not defined"

#105
by praveeny - opened

The commit https://huggingface.co/microsoft/phi-2/commit/eb8bbd1d37d258ea74fb082c53346d33056a83d4 appears to break the modeling code during training.

Error Stack

=> 588 key_layer = index_first_axis(
589 key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
590 )

NameError: name 'index_first_axis' is not defined

Investigation

I see that this import block:

from transformers.utils import (
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_flash_attn_2_available,
is_flash_attn_greater_or_equal_2_10,
logging,
replace_return_docstrings,
)

imports is_flash_attn_2_available, but I could not find that function in the transformers library on GitHub.

Because the condition below evaluates to false, index_first_axis never gets imported, and we hit the NameError:

if is_flash_attn_2_available():
from flash_attn import flash_attn_func, flash_attn_varlen_func
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
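
As a quick sanity check, you can call the same guard directly in your environment (a minimal sketch; it assumes your installed transformers version actually exposes is_flash_attn_2_available):

from transformers.utils import is_flash_attn_2_available

# If this prints False, the guarded flash_attn imports above are skipped,
# index_first_axis is never defined, and the NameError follows.
print("flash_attn_2 available:", is_flash_attn_2_available())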

Same error here.

Same here. The code worked two days ago, but I did not have enough resources to run it then. Now it fails with the same error.

NameError: name 'index_first_axis' is not defined

Any resolution to this? This is a breaking change.

I'm not using flash attention. That is the only resolution from my end lol

@pavankumarbalijepalli - any recommended alternatives?

Do not use flash attention for now. Try traditional fine-tuning with LoRA instead; see the sketch below.
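
A minimal LoRA setup sketch using peft (the target_modules names are an assumption for phi-2 and should be checked against your model's actual module names):

from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

# Load without attn_implementation="flash_attention_2", so the default
# attention path is used and the flash_attn imports are never needed.
model = AutoModelForCausalLM.from_pretrained(
    "microsoft/phi-2",
    trust_remote_code=True,
)

lora_config = LoraConfig(
    r=16,                  # rank of the LoRA update matrices
    lora_alpha=32,         # scaling applied to the LoRA updates
    lora_dropout=0.05,
    target_modules=["q_proj", "k_proj", "v_proj"],  # assumption: verify per model
    task_type="CAUSAL_LM",
)

model = get_peft_model(model, lora_config)
model.print_trainable_parameters()  # only the small adapter layers are trainable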

I realized that I had enabled flash attention while loading the model; you might have done the same thing. Comment out or remove the attn_implementation argument from your code and restart the session:

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map='auto',
    quantization_config=bnb_config,
    # attn_implementation="flash_attention_2",
    trust_remote_code=True,
)
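
If you do want flash attention once your environment supports it, a defensive pattern is to enable the flag only when the flash_attn package is actually importable (a sketch; model_name and bnb_config are the same variables as above, and it assumes a transformers version that accepts the attn_implementation argument):

import importlib.util
from transformers import AutoModelForCausalLM

# Fall back to the default "eager" implementation when flash-attn is not
# installed, so the guarded imports in the modeling code are never reached.
use_flash = importlib.util.find_spec("flash_attn") is not None

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    quantization_config=bnb_config,
    attn_implementation="flash_attention_2" if use_flash else "eager",
    trust_remote_code=True,
)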
