FlashAttention support for Mistral HF Implementation

#17
by mxxtsai - opened

Hi,

First of all, thank you all for releasing such an amazing model! I'm trying to further train Mistral-7B-v0.1 on some custom data.

I noticed that the official implementation (https://github.com/mistralai/mistral-src/blob/main/mistral/model.py) has Flash Attention built in.

However, the HuggingFace version (https://github.com/huggingface/transformers/blob/main/src/transformers/models/mistral/modeling_mistral.py) doesn't seem to have Flash Attention integrated.

Would it be possible to provide a script so that we can replace the standard attention with Flash Attention after loading the model via

model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1")

Thank you for your time and effort :)

It's still a WIP, but I used this and it seems to work fine for FA2: https://github.com/huggingface/transformers/pull/26464
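
For reference, here is a minimal sketch of how loading with FlashAttention-2 looks once that PR is in your installed transformers version (assuming the flash-attn package is installed and you are on a supported GPU; the exact flag name may differ in newer releases, which use attn_implementation="flash_attention_2" instead):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load Mistral-7B with FlashAttention-2 instead of the default attention.
# FA2 requires half precision (fp16 or bf16).
model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-v0.1",
    torch_dtype=torch.bfloat16,
    use_flash_attention_2=True,  # flag from the FA2 integration; newer versions use attn_implementation
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")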
