Unable to load the model for Torch versions starting from 2.0.1

#34
by benhachem - opened

Hello,

I am encountering an issue while attempting to load the Llama3 8B model with the pipeline function in bfloat16, using the latest version of the transformers library. With the latest torch version, loading fails with a runtime error (the same problem persists for every torch version starting from 2.0.1).

Code for loading:

import torch
from transformers import pipeline
# Loading the model
pipe = pipeline("text-generation", model="meta-llama/Meta-Llama-3-8B-Instruct", torch_dtype=torch.bfloat16, device_map="auto")

The resulting RuntimeError (it seems to be related to flash attention):

RuntimeError: Failed to import transformers.models.llama.modeling_llama because of the following error (look up to see its traceback):
/databricks/python/lib/python3.10/site-packages/flash_attn_2_cuda.cpython-310-x86_64-linux-gnu.so: undefined symbol: _ZN2at4_ops9_pad_enum4callERKNS_6TensorEN3c108ArrayRefINS5_6SymIntEEElNS5_8optionalIdEE
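
An undefined symbol inside flash_attn_2_cuda usually means the installed flash-attn binary was compiled against a different torch version than the one currently installed. A minimal sketch for collecting the relevant versions before reinstalling or reporting, using only the standard library; the helper names installed_version and report_versions are hypothetical:

```python
from importlib.metadata import PackageNotFoundError, version


def installed_version(*dist_names):
    """Return the installed version of the first distribution name found, or None."""
    for name in dist_names:
        try:
            return version(name)
        except PackageNotFoundError:
            continue  # try the next spelling
    return None


def report_versions():
    # flash-attn is published as "flash-attn", but its metadata directory
    # uses "flash_attn", so both spellings are tried.
    return {
        "torch": installed_version("torch"),
        "transformers": installed_version("transformers"),
        "flash-attn": installed_version("flash-attn", "flash_attn"),
    }


if __name__ == "__main__":
    for package, ver in report_versions().items():
        print(f"{package}: {ver or 'not installed'}")
```

If flash-attn reports a version but the import still fails with an undefined symbol, rebuilding flash-attn against the currently installed torch is the usual next step.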

This issue seems to be resolved by downgrading torch to a version earlier than 2.0.1, but then another issue arises during inference: those older torch versions do not implement torch.triu for bfloat16 tensors on CUDA, which results in the following error:

   1094 if sequence_length != 1:
-> 1095     causal_mask = torch.triu(causal_mask, diagonal=1)
   1096 causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
   1097 causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)

RuntimeError: "triu_tril_cuda_template" not implemented for 'BFloat16'

Could anyone please help me resolve these issues or suggest a workaround? I would greatly appreciate any assistance. Thank you!

I encounter the same issue. I am using AutoModelForCausalLM instead, but the error is the same as yours. My transformers version is 4.39.3 and my torch version is 2.0.1.

You'll likely need to update your transformers package to version 4.40.0, which supports Llama 3.
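
A minimal sketch of a version check you could run before loading the model, assuming plain numeric release strings such as "4.39.3" or "2.0.1+cu118". The helper names are hypothetical, and the parser ignores pre-release ordering (so "4.40.0rc1" counts as 4.40.0); packaging.version.Version handles those cases properly.

```python
import re


def parse_version(text):
    """Turn e.g. '4.40.0' or '2.0.1+cu118' into a tuple of ints like (4, 40, 0)."""
    public = text.split("+", 1)[0]  # drop local build tags such as '+cu118'
    parts = []
    for piece in public.split("."):
        match = re.match(r"\d+", piece)
        if match is None:
            break
        parts.append(int(match.group()))
        if match.end() != len(piece):
            break  # stop at pre-release suffixes like '0rc1'
    return tuple(parts)


def at_least(installed, minimum):
    a, b = parse_version(installed), parse_version(minimum)
    width = max(len(a), len(b))
    # Pad with zeros so that '4.40' compares equal to '4.40.0'.
    return a + (0,) * (width - len(a)) >= b + (0,) * (width - len(b))
```

For example, at_least(transformers.__version__, "4.40.0") is False for the 4.39.3 install mentioned above.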

However, we can get around the error mentioned above by downgrading PyTorch to version 2.0.1 and loading Llama3-8B in float16 precision rather than bfloat16. This should bypass the error because triu_tril_cuda_template is implemented for float16 in PyTorch 2.0.1, but note that it doesn't take advantage of the bfloat16 format.
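
A minimal sketch that picks the dtype name from the installed torch version, so the same loading code falls back to float16 only where needed. The cutoff defaults to 2.0.1 because that is the newest version this thread reports as failing with bfloat16; which release first implements the op is an assumption here, and pick_dtype_name is a hypothetical helper.

```python
import re


def _release(text):
    # '2.0.1+cu118' -> (2, 0, 1); stops at the first non-numeric component.
    parts = []
    for piece in text.split("+", 1)[0].split("."):
        match = re.match(r"\d+", piece)
        if match is None:
            break
        parts.append(int(match.group()))
    return tuple(parts)


def pick_dtype_name(torch_version, last_without_bf16_triu="2.0.1"):
    """Return 'float16' for torch releases assumed to lack CUDA bfloat16 triu, else 'bfloat16'."""
    if _release(torch_version) <= _release(last_without_bf16_triu):
        return "float16"
    return "bfloat16"
```

The result can be passed as torch_dtype=getattr(torch, pick_dtype_name(torch.__version__)) in the pipeline call. This only addresses the triu error, not the flash-attn import error.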

I have exactly the same problem as the original post, except that with torch==2.0.1 I hit the second bug (RuntimeError: "triu_tril_cuda_template" not implemented for 'BFloat16').

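I solved my problem by replacing
causal_mask = torch.triu(causal_mask, diagonal=1) with causal_mask = custom_triu(causal_mask), where custom_triu is defined as:

import torch

def custom_triu(input_tensor):
    # Equivalent to torch.triu(input_tensor, diagonal=1) for a 2-D tensor:
    # zero every element on or below the main diagonal (row >= col).
    rows, cols = input_tensor.shape
    # Build the index grids on the same device as the input (e.g. CUDA).
    row_indices = torch.arange(rows, device=input_tensor.device).unsqueeze(1).expand(rows, cols)
    col_indices = torch.arange(cols, device=input_tensor.device).unsqueeze(0).expand(rows, cols)
    mask = row_indices >= col_indices
    output_tensor = input_tensor.clone()
    output_tensor[mask] = 0
    return output_tensor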
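
torch.triu(x, diagonal=1) keeps the elements whose column index minus row index is at least 1 (strictly above the main diagonal) and zeroes the rest, so custom_triu must zero exactly the positions with row >= col. A small pure-Python sketch (no torch required; the helper names are hypothetical) that checks this index condition, including rectangular shapes:

```python
def triu_keep(rows, cols, diagonal=1):
    """Positions torch.triu(x, diagonal) keeps: those with col - row >= diagonal."""
    return [[col - row >= diagonal for col in range(cols)] for row in range(rows)]


def custom_zero_mask(rows, cols):
    """Positions custom_triu zeroes: those with row >= col."""
    return [[row >= col for col in range(cols)] for row in range(rows)]


def check_equivalent(rows, cols):
    # custom_triu zeroes a position exactly when triu(..., diagonal=1) drops it.
    keep = triu_keep(rows, cols)
    zero = custom_zero_mask(rows, cols)
    return all(
        keep[r][c] == (not zero[r][c]) for r in range(rows) for c in range(cols)
    )


for row in triu_keep(4, 4):
    print("".join("1" if kept else "0" for kept in row))
# prints:
# 0111
# 0011
# 0001
# 0000
```
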
Meta Llama org

This is caused by a breaking change in torch, which seems to have been fixed. Upgrading torch should be your best bet here 😉 It worked for me with torch 2.3.

ArthurZ changed discussion status to closed

I still encounter the issue with torch 2.3 and transformers 4.40.0.

Yes, I have the same issue. With torch 2.3.0 and transformers 4.40.1, flash attention throws a runtime error.

I happened to solve the issue by uninstalling torch and flash attention, then running pip install torch==2.2.1 torchvision==0.17.1 torchaudio==2.2.1 --index-url https://download.pytorch.org/whl/cu124 followed by pip install flash-attn --no-build-isolation. The issue was caused by a version incompatibility between torch and flash-attn. You may want to try different torch and CUDA versions depending on your setup.
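
After reinstalling, a quick way to confirm that the compiled flash-attn extension now loads against the installed torch is to import it directly. flash_attn_2_cuda is the module name from the traceback in the original post. A minimal sketch; the helper name try_imports is hypothetical:

```python
import importlib


def try_imports(*module_names):
    """Import modules in order; return (True, None) or (False, 'name: error')."""
    for name in module_names:
        try:
            importlib.import_module(name)
        except (ImportError, OSError) as exc:
            # A torch/flash-attn ABI mismatch surfaces as an ImportError
            # whose message contains 'undefined symbol'.
            return False, f"{name}: {exc}"
    return True, None


if __name__ == "__main__":
    # torch is imported first so the extension can resolve libtorch symbols.
    ok, error = try_imports("torch", "flash_attn_2_cuda")
    print("flash-attn extension loads" if ok else f"import failed -> {error}")
```
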
