Flash Attention 2

#32
by NickyNicky - opened

GPU: A100 40GB (Colab)


!python -c "import torch; assert torch.cuda.get_device_capability()[0] >= 8, 'Hardware not supported for Flash Attention'" # pass
!export CUDA_HOME=/usr/local/cuda-11.8  # note: `!export` runs in its own subshell in Colab; use %env CUDA_HOME=/usr/local/cuda-11.8 to make it persist
# !MAX_JOBS=4 pip install flash-attn --no-build-isolation
!MAX_JOBS=4 pip install flash-attn --no-build-isolation  -qqq
!pip install git+"https://github.com/HazyResearch/flash-attention.git#subdirectory=csrc/rotary" -qqq
!python -m pip install optimum -qqq
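
Before loading the model, it can help to confirm that the flash-attn wheel actually built and is importable. A minimal sanity check (not part of the original post):

```Python
# Sanity check (illustrative): verify flash-attn is importable and report its version.
import importlib.metadata

try:
    import flash_attn  # module installed by `pip install flash-attn`
    print("flash-attn version:", importlib.metadata.version("flash-attn"))
except ImportError as err:
    print("flash-attn did not install correctly:", err)
```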

import torch, transformers, torchvision
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
torch.__version__, transformers.__version__, torchvision.__version__ # ('2.0.1+cu118', '4.34.0.dev0', '0.15.2+cu118')

model_id = 'mistralai/Mistral-7B-Instruct-v0.1'

# quantization_config is assumed to be a BitsAndBytesConfig defined earlier in the notebook,
# e.g. BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16)
model = AutoModelForCausalLM.from_pretrained(model_id,
                                             device_map="auto",
                                             trust_remote_code=True,
                                             torch_dtype=torch.bfloat16,
                                             load_in_4bit=True,
                                             quantization_config=quantization_config,
                                             use_flash_attention_2=True,
                                             low_cpu_mem_usage=True,
                                             )

ERROR
```Python

ValueError Traceback (most recent call last)
in <cell line: 32>()
30 # from optimum.bettertransformer import BetterTransformer #flash attention 2
31
---> 32 model = AutoModelForCausalLM.from_pretrained(model_id,
33 device_map="auto",
34 trust_remote_code=True,

2 frames
/usr/local/lib/python3.10/dist-packages/transformers/modeling_utils.py in _check_and_enable_flash_attn_2(cls, config, torch_dtype, device_map)
1263 """
1264 if not cls._supports_flash_attn_2:
-> 1265 raise ValueError(
1266 "The current architecture does not support Flash Attention 2.0. Please open an issue on GitHub to "
1267 "request support for this architecture: https://github.com/huggingface/transformers/issues/new"

ValueError: The current architecture does not support Flash Attention 2.0. Please open an issue on GitHub to request support for this architecture: https://github.com/huggingface/transformers/issues/new
```
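
The ValueError above is raised from the `cls._supports_flash_attn_2` check visible in the traceback. A quick diagnostic sketch (not from the original thread) is to inspect that flag for the Mistral class in the installed transformers build:

```Python
# Diagnostic sketch: the check in modeling_utils.py fails when this class-level flag is
# False (or when the installed transformers build has no native Mistral classes at all).
import transformers

print(transformers.__version__)
mistral_cls = getattr(transformers, "MistralForCausalLM", None)
print(mistral_cls is not None and mistral_cls._supports_flash_attn_2)  # True on e.g. 4.34.0
```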

Other ERROR: BetterTransformer --> Flash Attention 2

!python -c "import torch; assert torch.cuda.get_device_capability()[0] >= 8, 'Hardware not supported for Flash Attention'"
!export CUDA_HOME=/usr/local/cuda-11.8
# !MAX_JOBS=4 pip install flash-attn --no-build-isolation
!MAX_JOBS=4 pip install flash-attn --no-build-isolation  -qqq
!pip install git+"https://github.com/HazyResearch/flash-attention.git#subdirectory=csrc/rotary" -qqq
!python -m pip install optimum -qqq


from optimum.bettertransformer import BetterTransformer #flash attention 2

model = AutoModelForCausalLM.from_pretrained(model_id,
                                             device_map="auto",
                                             trust_remote_code=True,
                                             torch_dtype=torch.bfloat16,
                                             load_in_4bit=True,
                                             quantization_config=quantization_config,
                                             # use_flash_attention_2=True,
                                             low_cpu_mem_usage=True,
                                             )

model = BetterTransformer.transform(model, keep_original_model=False) # flash attention 2 (this call raises ERROR 2 below)

ERROR 2
```Python

NotImplementedError Traceback (most recent call last)
in <cell line: 43>()
41 )
42
---> 43 model = BetterTransformer.transform(model, keep_original_model=False) #flash attention 2
44
45

1 frames
/usr/local/lib/python3.10/dist-packages/optimum/bettertransformer/transformation.py in transform(model, keep_original_model, max_memory, offload_dir, **kwargs)
226 )
227 if not BetterTransformerManager.supports(model.config.model_type):
--> 228 raise NotImplementedError(
229 f"The model type {model.config.model_type} is not yet supported to be used with BetterTransformer. Feel free"
230 f" to open an issue at https://github.com/huggingface/optimum/issues if you would like this model type to be supported."

NotImplementedError: The model type mistral is not yet supported to be used with BetterTransformer. Feel free to open an issue at https://github.com/huggingface/optimum/issues
if you would like this model type to be supported. Currently supported models are:
dict_keys(['albert', 'bark', 'bart', 'bert', 'bert-generation',
'blenderbot', 'bloom', 'camembert', 'blip-2', 'clip', 'codegen', 'data2vec-text', 'deit',
'distilbert', 'electra', 'ernie', 'fsmt', 'falcon', 'gpt2', 'gpt_bigcode', 'gptj', 'gpt_neo',
'gpt_neox', 'hubert', 'layoutlm', 'llama', 'm2m_100', 'marian', 'markuplm', 'mbart',
'opt', 'pegasus', 'rembert', 'prophetnet', 'roberta', 'roc_bert', 'roformer', 'splinter',
'tapas', 't5', 'vilt', 'vit', 'vit_mae', 'vit_msn', 'wav2vec2', 'whisper', 'xlm-roberta', 'yolos']).
```
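
The supported-models list in the error message is what `BetterTransformerManager.supports` checks against. A quick hypothetical diagnostic (the `optimum.bettertransformer.models` import path mirrors what `transformation.py` uses internally and may not be a stable public API):

```Python
# Hypothetical diagnostic: ask optimum whether a given model type is covered by BetterTransformer.
from optimum.bettertransformer.models import BetterTransformerManager

print(BetterTransformerManager.supports("mistral"))  # False at the time of this thread
print(BetterTransformerManager.supports("llama"))    # True, per the list in the error above
```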

I'm getting the same error. They say Mistral is included in the supported models, but I don't know why it is giving this error.

Hi @NickyNicky and @sgauravm

If you install the latest version of transformers:

pip install -U transformers

Flash Attention 2 should be supported.

Check out this specific section of the docs: https://huggingface.co/docs/transformers/model_doc/mistral#combining-mistral-and-flash-attention-2 for more details
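
For reference, the pattern described in that docs section boils down to something like the sketch below (assumes transformers >= 4.34.0, accelerate and flash-attn installed, and a GPU with compute capability >= 8.0; the prompt is just illustrative):

```Python
# Minimal sketch: load Mistral with Flash Attention 2 on transformers >= 4.34.0
# (flash-attn kernels require half precision, hence torch_dtype=torch.bfloat16).
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "mistralai/Mistral-7B-Instruct-v0.1"

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    use_flash_attention_2=True,
    device_map="auto",
)

inputs = tokenizer("My favourite condiment is", return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```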

Thank you very much,
I have the latest version of transformers.

@NickyNicky
Your first script

!python -c "import torch; assert torch.cuda.get_device_capability()[0] >= 8, 'Hardware not supported for Flash Attention'" # pass
!export CUDA_HOME=/usr/local/cuda-11.8
# !MAX_JOBS=4 pip install flash-attn --no-build-isolation
!MAX_JOBS=4 pip install flash-attn --no-build-isolation  -qqq
!pip install git+"https://github.com/HazyResearch/flash-attention.git#subdirectory=csrc/rotary" -qqq
!python -m pip install optimum -qqq

import torch, transformers, torchvision
torch.__version__, transformers.__version__, torchvision.__version__ # ('2.0.1+cu118', '4.34.0.dev0', '0.15.2+cu118')

model_id = 'mistralai/Mistral-7B-Instruct-v0.1'

model = AutoModelForCausalLM.from_pretrained(model_id,
                                             device_map="auto",
                                             trust_remote_code=True,
                                             torch_dtype=torch.bfloat16,
                                             load_in_4bit=True,
                                             quantization_config=quantization_config,
                                             use_flash_attention_2=True,
                                             low_cpu_mem_usage=True,
                                             )

This should work if you have the latest transformers installed. However, Mistral is not supported in BetterTransformer yet; we will add support for F.SDPA natively in transformers core soon.
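
For context, F.SDPA refers to PyTorch's `torch.nn.functional.scaled_dot_product_attention` (available since PyTorch 2.0), which can dispatch to FlashAttention-style fused kernels on supported GPUs. A minimal standalone sketch (shapes are illustrative, not tied to Mistral):

```Python
# Standalone sketch of F.SDPA (requires PyTorch >= 2.0 and a CUDA device for the fused kernels).
import torch
import torch.nn.functional as F

batch, heads, seq_len, head_dim = 1, 32, 128, 128
q = torch.randn(batch, heads, seq_len, head_dim, device="cuda", dtype=torch.bfloat16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# is_causal=True applies the causal mask used by decoder-only models.
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(out.shape)  # torch.Size([1, 32, 128, 128])
```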

With these versions it works.

import torch, transformers
torch.__version__,transformers.__version__
('2.0.1+cu118', '4.34.0')
NickyNicky changed discussion status to closed
