Support for gradient checkpointing and Flash Attention

#14
by sidnb13 - opened

Are there any plans to support gradient checkpointing and Flash Attention for training/finetuning? It would be very helpful for getting this working with fewer resources.
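For context, this is the kind of thing I mean for the gradient checkpointing half — a minimal sketch, assuming the model follows the Transformers PreTrainedModel interface (the model id below is a placeholder, not this repository's actual id):

import torch
from transformers import AutoModelForCausalLM

# Placeholder model id, not this repository's actual id.
model = AutoModelForCausalLM.from_pretrained(
    "<model-repo>",
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
)

# Recompute activations in the backward pass instead of storing them,
# trading extra compute for a lower memory footprint. This requires the
# model class to set supports_gradient_checkpointing = True.
model.gradient_checkpointing_enable()
model.train()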

chenkq pinned discussion

I think FlashAttention is already used under certain conditions. The attention implementation calls PyTorch's scaled_dot_product_attention, which dispatches to a FlashAttention kernel when its requirements are met. For debugging, you can force the use of this kernel with a context manager:

import torch

# Force the FlashAttention SDPA kernel and disable the math and
# memory-efficient fallbacks, so it fails loudly if FA cannot be used.
with torch.backends.cuda.sdp_kernel(
    enable_flash=True,
    enable_math=False,
    enable_mem_efficient=False,
):
    model.generate(**input)

Note that in the referenced code there is a branch which may execute a naive attention implementation, so even if you enforce FlashAttention use on the PyTorch side, you should still make sure that the if-statement takes the first branch (the SDPA path) rather than the naive one.
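To make that concrete, here is a hypothetical sketch of the kind of branch being described (not the repository's actual code; the use_sdpa flag is made up):

import torch
import torch.nn.functional as F

def attention(q, k, v, use_sdpa: bool):
    if use_sdpa:
        # First branch: SDPA, which can dispatch to the FlashAttention kernel.
        return F.scaled_dot_product_attention(q, k, v)
    # Fallback branch: naive attention, which never uses FlashAttention,
    # no matter how the sdp_kernel context manager is configured.
    scale = q.shape[-1] ** -0.5
    attn = (q @ k.transpose(-2, -1)) * scale
    return attn.softmax(dim=-1) @ v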

Knowledge Engineering Group (KEG) & Data Mining at Tsinghua University org

If you upgrade to PyTorch 2.2.0, you can use Flash Attention 2.0 directly through PyTorch's integrated implementation. Follow the instructions in the previous comment, but there is no need to pull a separate branch.
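For reference, PyTorch 2.2 also adds the newer torch.nn.attention API for pinning the FlashAttention backend; a minimal sketch (input is assumed to be a dict of tokenized inputs, as in the earlier snippet):

from torch.nn.attention import SDPBackend, sdpa_kernel

# Restrict scaled_dot_product_attention to the FlashAttention backend.
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    model.generate(**input)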

zRzRzRzRzRzRzR changed discussion status to closed
