FalconForCausalLM does not support Flash Attention 2.0 yet

#98
by Menouar - opened

Dear Repository Owners,
The Falcon model loaded directly through the library's FalconForCausalLM class supports Flash Attention 2:

import torch
from transformers import FalconForCausalLM

# model_id and bnb_config (a BitsAndBytesConfig) are defined earlier in my script
model = FalconForCausalLM.from_pretrained(
    model_id,
    trust_remote_code=True,
    device_map="auto",
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,
    quantization_config=bnb_config,
)

However, the same checkpoint loaded through AutoModelForCausalLM with trust_remote_code=True does not:

import torch
from transformers import AutoModelForCausalLM

# Same model_id and bnb_config as above
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    trust_remote_code=True,
    device_map="auto",
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,
    quantization_config=bnb_config,
)

I encountered the following error:

ValueError: FalconForCausalLM does not support Flash Attention 2.0 yet.

This discrepancy seems to occur because the model's custom code was originally hosted on the Hub and the architecture was only later incorporated into the library: with trust_remote_code=True, AutoModelForCausalLM loads the older Hub modeling code, which does not declare Flash Attention 2 support.
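For reference, a possible workaround I am considering (a sketch only, assuming the checkpoint's architecture is already covered by the in-library Falcon implementation) is to drop trust_remote_code so that AutoModelForCausalLM resolves to the native class instead of the Hub code:

import torch
from transformers import AutoModelForCausalLM

# Sketch of a workaround, not verified on every Falcon checkpoint:
# without trust_remote_code, AutoModelForCausalLM should resolve to the
# in-library FalconForCausalLM, which supports Flash Attention 2.
model = AutoModelForCausalLM.from_pretrained(
    model_id,  # same checkpoint as above
    device_map="auto",
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,
    quantization_config=bnb_config,  # same BitsAndBytesConfig as above
)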
@Rocketknight1

Thanks
