Undefined symbol error
I am getting an error just trying to run this model. Can anyone help me identify what the issue is and how to fix it? I am using this code in a Databricks environment just to load the model:
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
tokenizer = AutoTokenizer.from_pretrained("google/gemma-7b-it") # google/gemma-2b-it
model = AutoModelForCausalLM.from_pretrained("google/gemma-7b-it", device_map="auto", torch_dtype=torch.bfloat16)
However, when it tries to load, I receive this error:
Failed to import transformers.models.gemma.modeling_gemma because of the following error (look up to see its traceback):
/databricks/python/lib/python3.10/site-packages/flash_attn_2_cuda.cpython-310-x86_64-linux-gnu.so: undefined symbol: _ZN2at4_ops9_pad_enum4callERKNS_6TensorEN3c108ArrayRefINS5_6SymIntEEElNS5_8optionalIdEE
I am installing these library versions: transformers 4.38.1, accelerate 0.23.0, torch 2.2.1+cu121
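The list above covers torch, transformers and accelerate but not flash-attn, which is the package whose compiled binary (flash_attn_2_cuda) appears in the error. A minimal sketch for printing all four versions from the same environment, using only the standard library's importlib.metadata; installed_versions is a hypothetical helper name, not part of any library:

```python
import importlib.metadata as md

# Distribution names to look up; "flash_attn" is the flash-attn (FA2) package.
PACKAGES = ["torch", "transformers", "accelerate", "flash_attn"]


def installed_versions(packages=PACKAGES):
    """Return {package: version string, or None if it is not installed}."""
    versions = {}
    for name in packages:
        try:
            versions[name] = md.version(name)
        except md.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in installed_versions().items():
        print(f"{name}: {version or 'not installed'}")
```

Running this in the same notebook that fails makes it easier to see which torch and flash-attn builds are actually being loaded together.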
Any help understanding what this error means exactly and why it is raised, or how to fix it, would be appreciated. I cannot seem to find anything that helps.
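For context on the message itself: an "undefined symbol" error when Python loads a compiled extension usually means the binary was built against a different version of a library it links to. Here the missing symbol demangles to a PyTorch ATen operator (at::_ops::_pad_enum::call), which suggests the installed flash-attn binary was compiled for a different torch version than the one in the environment. The sketch below tries to load that extension directly, outside transformers, to confirm flash-attn is the failing component. check_flash_attn_extension is a hypothetical name, and the sketch assumes the extension module is named flash_attn_2_cuda, as in the path from the error:

```python
import importlib


def check_flash_attn_extension():
    """Try to load flash-attn's compiled CUDA extension on its own.

    Returns "ok", "torch not installed", "not installed",
    or "failed to load: <loader message>".
    """
    try:
        # Import torch first so libtorch is loaded before the extension.
        import torch  # noqa: F401
    except ImportError:
        return "torch not installed"
    try:
        importlib.import_module("flash_attn_2_cuda")
    except ModuleNotFoundError:
        return "not installed"
    except ImportError as exc:
        # A torch/flash-attn ABI mismatch surfaces here, e.g.
        # "... undefined symbol: _ZN2at4_ops..."
        return f"failed to load: {exc}"
    return "ok"


if __name__ == "__main__":
    print(check_flash_attn_extension())
```

If this reports "failed to load" with the same undefined symbol, the problem is the flash-attn build rather than the Gemma code in transformers.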
Hi @Cgodwin,
You might be getting that because of a weird interaction between torch 2.2.1 and FA2 (flash-attn 2). The PyTorch team recently added FA2 support to SDPA, so maybe that's why.
Can you try downgrading torch to 2.1.2, or alternatively upgrading flash-attn (FA2) to its latest version?
If you confirm it's a torch issue, can you open an issue on their repository: https://github.com/pytorch/pytorch
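The two suggested fixes (downgrading torch to 2.1.2, or upgrading flash-attn) can be applied from inside a notebook by calling pip through the current interpreter. A minimal sketch, assuming pip is available in that environment; pip_fix_command and run_fix are hypothetical helper names. The --no-build-isolation flag follows the flash-attn installation instructions and lets a source build see the already-installed torch:

```python
import subprocess
import sys


def pip_fix_command(option):
    """Build a pip command (argv list) for one of the two suggested fixes.

    "downgrade-torch":    pin torch to 2.1.2
    "upgrade-flash-attn": upgrade flash-attn to its latest release
    """
    base = [sys.executable, "-m", "pip", "install"]
    if option == "downgrade-torch":
        return base + ["torch==2.1.2"]
    if option == "upgrade-flash-attn":
        return base + ["--upgrade", "flash-attn", "--no-build-isolation"]
    raise ValueError(f"unknown option: {option!r}")


def run_fix(option):
    """Run the chosen pip command; raises CalledProcessError if pip fails."""
    subprocess.run(pip_fix_command(option), check=True)


if __name__ == "__main__":
    # Print the commands instead of running them.
    for opt in ("downgrade-torch", "upgrade-flash-attn"):
        print(" ".join(pip_fix_command(opt)))
```

In a Databricks notebook you will likely need to restart the Python process after changing packages (for example with dbutils.library.restartPython()) so the new binaries are actually loaded.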
It's insane that this isn't documented anywhere, and you solved the issue. Thanks @ybelkada!