"triu_tril_cuda_template" not implemented for 'BFloat16'

#52
by Ashmal - opened

Hello, I am trying to fine-tune this model with the axolotl framework, using LoRA fine-tuning with load_in_8_bit=True. Is there a specific reason why bf16 does not work in this case? I am getting the following error:

Traceback (most recent call last):
File "/home/ashmal.vayani/anaconda3/envs/axolotl/lib/python3.10/runpy.py", line 196, in _run_module_as_main
return _run_code(code, main_globals, None,
File "/home/ashmal.vayani/anaconda3/envs/axolotl/lib/python3.10/runpy.py", line 86, in _run_code
exec(code, run_globals)
File "/mnt/beegfs/fahad.khan/axolotl/src/axolotl/cli/train.py", line 59, in <module>
fire.Fire(do_cli)
File "/home/ashmal.vayani/anaconda3/envs/axolotl/lib/python3.10/site-packages/fire/core.py", line 143, in Fire
component_trace = _Fire(component, args, parsed_flag_args, context, name)
File "/home/ashmal.vayani/anaconda3/envs/axolotl/lib/python3.10/site-packages/fire/core.py", line 477, in _Fire
component, remaining_args = _CallAndUpdateTrace(
File "/home/ashmal.vayani/anaconda3/envs/axolotl/lib/python3.10/site-packages/fire/core.py", line 693, in _CallAndUpdateTrace
component = fn(*varargs, **kwargs)
File "/mnt/beegfs/fahad.khan/axolotl/src/axolotl/cli/train.py", line 35, in do_cli
return do_train(parsed_cfg, parsed_cli_args)
File "/mnt/beegfs/fahad.khan/axolotl/src/axolotl/cli/train.py", line 55, in do_train
return train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
File "/mnt/beegfs/fahad.khan/axolotl/src/axolotl/train.py", line 163, in train
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
File "/home/ashmal.vayani/anaconda3/envs/axolotl/lib/python3.10/site-packages/transformers/trainer.py", line 1780, in train
return inner_training_loop(
File "/home/ashmal.vayani/anaconda3/envs/axolotl/lib/python3.10/site-packages/transformers/trainer.py", line 2118, in _inner_training_loop
tr_loss_step = self.training_step(model, inputs)
File "/home/ashmal.vayani/anaconda3/envs/axolotl/lib/python3.10/site-packages/transformers/trainer.py", line 3036, in training_step
loss = self.compute_loss(model, inputs)
File "/mnt/beegfs/fahad.khan/axolotl/src/axolotl/core/trainer_builder.py", line 485, in compute_loss
return super().compute_loss(model, inputs, return_outputs=return_outputs)
File "/home/ashmal.vayani/anaconda3/envs/axolotl/lib/python3.10/site-packages/transformers/trainer.py", line 3059, in compute_loss
outputs = model(**inputs)
File "/home/ashmal.vayani/anaconda3/envs/axolotl/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/home/ashmal.vayani/anaconda3/envs/axolotl/lib/python3.10/site-packages/torch/nn/parallel/distributed.py", line 1156, in forward
output = self._run_ddp_forward(*inputs, **kwargs)
File "/home/ashmal.vayani/anaconda3/envs/axolotl/lib/python3.10/site-packages/torch/nn/parallel/distributed.py", line 1110, in _run_ddp_forward
return module_to_run(*inputs[0], **kwargs[0]) # type: ignore[index]
File "/home/ashmal.vayani/anaconda3/envs/axolotl/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/home/ashmal.vayani/anaconda3/envs/axolotl/lib/python3.10/site-packages/accelerate/utils/operations.py", line 822, in forward
return model_forward(*args, **kwargs)
File "/home/ashmal.vayani/anaconda3/envs/axolotl/lib/python3.10/site-packages/accelerate/utils/operations.py", line 810, in __call__
return convert_to_fp32(self.model_forward(*args, **kwargs))
File "/home/ashmal.vayani/anaconda3/envs/axolotl/lib/python3.10/site-packages/torch/amp/autocast_mode.py", line 14, in decorate_autocast
return func(*args, **kwargs)
File "/home/ashmal.vayani/anaconda3/envs/axolotl/lib/python3.10/site-packages/peft/peft_model.py", line 1129, in forward
return self.base_model(
File "/home/ashmal.vayani/anaconda3/envs/axolotl/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/home/ashmal.vayani/anaconda3/envs/axolotl/lib/python3.10/site-packages/peft/tuners/tuners_utils.py", line 161, in forward
return self.model.forward(*args, **kwargs)
File "/home/ashmal.vayani/anaconda3/envs/axolotl/lib/python3.10/site-packages/accelerate/hooks.py", line 166, in new_forward
output = module._old_forward(*args, **kwargs)
File "/home/ashmal.vayani/anaconda3/envs/axolotl/lib/python3.10/site-packages/transformers/models/cohere/modeling_cohere.py", line 1099, in forward
outputs = self.model(
File "/home/ashmal.vayani/anaconda3/envs/axolotl/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/home/ashmal.vayani/anaconda3/envs/axolotl/lib/python3.10/site-packages/accelerate/hooks.py", line 166, in new_forward
output = module._old_forward(*args, **kwargs)
File "/home/ashmal.vayani/anaconda3/envs/axolotl/lib/python3.10/site-packages/transformers/models/cohere/modeling_cohere.py", line 889, in forward
causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position)
File "/home/ashmal.vayani/anaconda3/envs/axolotl/lib/python3.10/site-packages/transformers/models/cohere/modeling_cohere.py", line 975, in _update_causal_mask
causal_mask = torch.triu(causal_mask, diagonal=1)
RuntimeError: "triu_tril_cuda_template" not implemented for 'BFloat16'

You should upgrade torch from 2.0.1 to 2.2.0. The traceback fails in torch.triu on a bfloat16 CUDA tensor, which the older release does not implement.
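To see which torch version an environment actually has before and after upgrading, something like the following can help. It is only a sketch: the helper names (version_tuple, is_older_than, installed_version) are made up for illustration, and the 2.2.0 threshold is simply the version suggested in this thread, not an officially documented minimum.

```python
import re
from importlib import metadata


def version_tuple(version):
    """Turn a string such as '2.0.1+cu118' into (2, 0, 1).

    The local build suffix after '+' is ignored, and only the leading
    digits of each of the first three components are kept.
    """
    public = version.split("+", 1)[0]
    parts = []
    for piece in public.split(".")[:3]:
        match = re.match(r"\d+", piece)
        parts.append(int(match.group()) if match else 0)
    # Pad short versions such as '2.2' so tuple comparison behaves.
    while len(parts) < 3:
        parts.append(0)
    return tuple(parts)


def is_older_than(installed, minimum="2.2.0"):
    """Return True if `installed` is strictly older than `minimum`."""
    return version_tuple(installed) < version_tuple(minimum)


def installed_version(package="torch"):
    """Return the installed version string, or None if it is not installed."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


if __name__ == "__main__":
    found = installed_version("torch")
    if found is None:
        print("torch is not installed in this environment")
    elif is_older_than(found):
        print(f"torch {found} is older than 2.2.0; an upgrade may be needed")
    else:
        print(f"torch {found} is at least 2.2.0")
```

Running this inside the same conda environment that launches the training job avoids checking a different interpreter by mistake.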

That doesn't fully solve the problem for me. After upgrading to torch==2.2.0, I get a different error:

import flash_attn_2_cuda as flash_attn_cuda
ImportError: /fsx/m4/conda_installation/envs/shared-m4-2024-05-03/lib/python3.8/site-packages/flash_attn_2_cuda.cpython-38-x86_64-linux-gnu.so: undefined symbol: _ZN2at4_ops9_pad_enum4callERKNS_6TensorEN3c108ArrayRefINS5_6SymIntEEElNS5_8optionalIdEE

See this discussion: https://huggingface.co/meta-llama/Meta-Llama-3-8B/discussions/34

@HugoLaurencon
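The undefined symbol suggests that the installed flash-attn binary was built against a different torch version. You can uninstall flash-attn first and then reinstall it, so that it is rebuilt against the upgraded torch, with pip install flash-attn --no-build-isolation --no-cache-dir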
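After reinstalling, a quick import check can show whether the binary mismatch is gone before a full training run is launched. This is a rough sketch using only the standard library; check_import and report are hypothetical helper names, and the only assumption about the packages is that they are importable as torch and flash_attn.

```python
import importlib


def check_import(module_name):
    """Try to import a module and return (ok, error_message).

    An 'undefined symbol' failure from a compiled extension surfaces as
    ImportError, so it is caught and reported here instead of crashing.
    """
    try:
        importlib.import_module(module_name)
    except ImportError as exc:
        return False, str(exc)
    return True, ""


def report(module_names):
    """Return one human-readable status line per module name."""
    lines = []
    for name in module_names:
        ok, message = check_import(name)
        status = "OK" if ok else f"FAILED: {message}"
        lines.append(f"{name}: {status}")
    return lines
```

For example, print("\n".join(report(["torch", "flash_attn"]))) run in the training environment should list both modules as OK once flash-attn has been rebuilt against the installed torch.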

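from transformers import LlamaForCausalLM
# model_name and model_args are defined earlier in my script
model_args['attn_implementation'] = 'flash_attention_2'
model = LlamaForCausalLM.from_pretrained(model_name, **model_args).eval()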

Setting attn_implementation to 'flash_attention_2' works for me.
