RuntimeError: FlashAttention only support fp16 and bf16 data type during fine-tuning

#11
by faizsameerahmed96 - opened

The hyperparameters I am using:

training_config = {
    "bf16": True,
    "do_eval": False,
    "learning_rate": 0.00001,
    "lr_scheduler_type": "cosine",
    "log_level": "info",
    "logging_steps": 30,
    "logging_strategy": "steps",
    "num_train_epochs": 5,
    "max_steps": -1,
    "output_dir": "./workspace/checkpoint_dir",
    "overwrite_output_dir": True,
    "per_device_eval_batch_size": 4,
    "remove_unused_columns": True,
    "save_steps": 100,
    "save_total_limit": 1,
    "seed": 0,
    "gradient_checkpointing": True,
    "gradient_checkpointing_kwargs": {"use_reentrant": False},
    "gradient_accumulation_steps": 1,
    "warmup_ratio": 0.2,
}

I am loading the model with:

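from transformers import AutoModelForCausalLM, AutoTokenizer

checkpoint_path = "microsoft/Phi-3-small-8k-instruct"
model_kwargs = dict(
    use_cache=False,
    trust_remote_code=True,
    attn_implementation="flash_attention_2",
    torch_dtype="auto",
    device_map=None,
)
model = AutoModelForCausalLM.from_pretrained(checkpoint_path, **model_kwargs)
# Assumed: the tokenizer is loaded from the same checkpoint (not shown in the original post)
tokenizer = AutoTokenizer.from_pretrained(checkpoint_path, trust_remote_code=True)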

Then I start training:

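from transformers import TrainingArguments
from trl import SFTTrainer

# Assumed: the config dict above is wrapped in TrainingArguments (not shown in the original post)
train_conf = TrainingArguments(**training_config)

trainer = SFTTrainer(
    model=model,
    args=train_conf,
    train_dataset=processed_dataset,
    max_seq_length=8192,
    dataset_text_field="text",
    tokenizer=tokenizer,
    packing=True,
)
train_result = trainer.train()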

I am getting the following error

    107 # if out.isnan().any() or softmax_lse.isnan().any():
    108 #     breakpoint()
    109 return out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state

RuntimeError: FlashAttention only support fp16 and bf16 data type

I used the exact same config when fine-tuning Phi-3-mini-128k without any issues. Is anyone else facing the same issue?

Microsoft org

Hi!
FlashAttention, as well as the block-sparse attention kernel, requires the model to be trained in fp16 or bf16. Is there a reason bfloat16 might not work for your use case?
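To check whether anything in a loaded model sits outside fp16/bf16, a quick audit of parameter dtypes may help. This is a minimal sketch in plain PyTorch; the helper `non_half_params` is hypothetical (not part of transformers or flash-attn), and it only inspects parameters, not intermediate activations such as q and k. With the model from the original post, one would call `non_half_params(model)` after loading.

```python
import torch
from torch import nn

# dtypes accepted by the FlashAttention kernels
HALF_DTYPES = (torch.float16, torch.bfloat16)


def non_half_params(model: nn.Module) -> list[tuple[str, torch.dtype]]:
    """Hypothetical helper: return (name, dtype) for every parameter not in fp16/bf16."""
    return [
        (name, param.dtype)
        for name, param in model.named_parameters()
        if param.dtype not in HALF_DTYPES
    ]


# Toy stand-in for a real model: cast everything to bf16,
# then push one submodule back to fp32 to simulate a stray layer.
toy = nn.Sequential(nn.Linear(4, 4), nn.LayerNorm(4)).to(torch.bfloat16)
toy[1].float()
stray = non_half_params(toy)
print(stray)
```

An empty list means every parameter is already fp16/bf16, so any remaining dtype error would have to come from activations computed in fp32.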

Hi,

I'm getting the same error, even with bf16=True in the training arguments.

Not sure if it's the correct fix, but here is how I made it work. In positional_embedding.py at
https://huggingface.co/microsoft/Phi-3-small-8k-instruct/blob/f5527db8a43fc9a4bf17c5b754251e1efe1d4ad3/positional_embedding.py#L269
cast q and k back to q's dtype after the rotary mapping:

# Cast the rotated q and k back to q's dtype so FlashAttention receives fp16/bf16 tensors
return (
    apply_rotary_pos_emb(
        q, cos_cached[seqlen_offset:seq_len], sin_cached[seqlen_offset:seq_len], seq_dimension=seq_dimension
    ).to(q.dtype),
    apply_rotary_pos_emb(
        k, cos_cached[seqlen_offset:seq_len], sin_cached[seqlen_offset:seq_len], seq_dimension=seq_dimension
    ).to(q.dtype),
)

Thanks for the answer.
It happens when device_map="auto" (or anything other than None). It might be a problem related to FlashAttention and multi-GPU training. If you have a fix, please share it.

I will be doing another batch of training over the weekend and will try out @ecocytus11's solution. Thanks!

Facing the same issue with both the 8k and 128k small models.
