Can this be fine-tuned with triton-backed flash attention and ALiBi using the Hugging Face Transformers Trainer?

#13
by winglian - opened

I gave it a shot and wanted to make sure it was even possible before going down a rabbit hole. Thanks!
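For context, the triton attention path on mosaicml/mpt-7b is typically enabled like this, following the model-card pattern; the dtype and exact setup here are illustrative, not necessarily the configuration that produced the traceback below:

```python
import torch
import transformers

# Load the MPT-7B config with remote code and switch the attention implementation
# to the triton flash-attention kernel; ALiBi is already on in the default config.
config = transformers.AutoConfig.from_pretrained('mosaicml/mpt-7b', trust_remote_code=True)
config.attn_config['attn_impl'] = 'triton'

model = transformers.AutoModelForCausalLM.from_pretrained(
    'mosaicml/mpt-7b',
    config=config,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
)
```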

  File "/root/.cache/huggingface/modules/transformers_modules/mosaicml/mpt-7b/d8304854d4877849c3c0a78f3469512a84419e84/attention.py", line 171, in forward                                                                                                                             
    (context, attn_weights) = self.attn_fn(query, key, value, self.n_heads, softmax_scale=self.softmax_scale, attn_bias=attn_bias, key_padding_mask=key_padding_mask, is_causal=is_causal, dropout_p=self.attn_dropout_p, training=self.training, needs_weights=needs_weights)         
  File "/root/.cache/huggingface/modules/transformers_modules/mosaicml/mpt-7b/d8304854d4877849c3c0a78f3469512a84419e84/attention.py", line 111, in triton_flash_attn_fn                                                                                                                
    attn_output = flash_attn_triton.flash_attn_func(query, key, value, attn_bias, reset_is_causal, softmax_scale)
  File "/root/miniconda3/envs/py3.9/lib/python3.9/site-packages/torch/autograd/function.py", line 506, in apply                                                                                                                                                                        
    return super().apply(*args, **kwargs)  # type: ignore[misc]                                                                                                                                                                                                                        
  File "/root/miniconda3/envs/py3.9/lib/python3.9/site-packages/flash_attn/flash_attn_triton.py", line 810, in forward                     
    o, lse, ctx.softmax_scale = _flash_attn_forward(                                                                                       
  File "/root/miniconda3/envs/py3.9/lib/python3.9/site-packages/flash_attn/flash_attn_triton.py", line 599, in _flash_attn_forward
    assert bias.dtype in [q.dtype, torch.float]
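The assert that trips, in flash_attn's `_flash_attn_forward`, only accepts a bias that is float32 or the same dtype as the query. My guess (an assumption, not a verified diagnosis) is that mixed-precision training under the Trainer leaves the ALiBi bias and the query in different dtypes. A minimal illustration with stand-in tensors (shapes are illustrative):

```python
import torch

# Stand-in query and ALiBi-style bias tensors in mismatched half-precision dtypes.
query = torch.randn(1, 128, 8, 64, dtype=torch.float16)
attn_bias = torch.zeros(1, 8, 1, 128, dtype=torch.bfloat16)

# The same check that raises in flash_attn_triton._flash_attn_forward:
print(attn_bias.dtype in [query.dtype, torch.float])  # False -> AssertionError

# Casting the bias to the query dtype (or keeping it in float32) satisfies the check.
attn_bias = attn_bias.to(query.dtype)
print(attn_bias.dtype in [query.dtype, torch.float])  # True
```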
Mosaic ML, Inc. org

Hi @winglian, we have not tested fine-tuning with the HF Trainer, so I can't guarantee compatibility.

You can find instructions for fine-tuning with Composer and our LLM Foundry codebase here: https://github.com/mosaicml/llm-foundry/tree/main/scripts/train#llm-finetuning. We are committed to maintaining this repo for the community and customers, and you can file GitHub issues directly there!

abhi-mosaic changed discussion status to closed
