Bug: FlashAttention forward only supports head dimension at most 256

#3
by Xidong - opened
```
out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.fwd(

RuntimeError: FlashAttention forward only supports head dimension at most 256
```

Zyphra org

This isn't a bug. You can't use Flash Attention with our HF implementation because of the concat before the shared attention layer: the concatenation pushes the per-head dimension past FlashAttention's 256 limit. We got around this by adding split-head support in our inference stack, which we're working on upstreaming to https://github.com/Zyphra/Zamba-torch
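Roughly, the issue looks like this (the shapes below are made up for illustration, not our exact config):

```python
import torch

# Illustrative shapes only -- not the real Zamba configuration.
batch, seq_len, n_heads, head_dim = 1, 16, 8, 160

hidden = torch.randn(batch, seq_len, n_heads * head_dim)
embeds = torch.randn(batch, seq_len, n_heads * head_dim)

# Concatenating along the feature dimension before the shared attention block
# doubles what each head sees: 160 -> 320 in this example.
attn_input = torch.cat([hidden, embeds], dim=-1)
per_head_dim = attn_input.shape[-1] // n_heads
print(per_head_dim)  # 320, which exceeds FlashAttention's 256 head-dimension limit
```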

In the meantime, we're going to add an assertion here that disables FA2 until we've figured out the HF port of our FA2 changes. Can you try with non-flash attention?
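For example, something like this should work to load the model without FlashAttention. The model id below is just a placeholder for whichever checkpoint you're using, and `attn_implementation="eager"` is the standard Transformers argument for selecting the plain PyTorch attention path:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Placeholder model id for illustration; substitute the checkpoint you are actually using.
model_id = "Zyphra/Zamba-7B-v1"

tokenizer = AutoTokenizer.from_pretrained(model_id)

# attn_implementation="eager" selects the plain PyTorch attention path,
# sidestepping the FlashAttention kernel and its 256 head-dimension limit.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    attn_implementation="eager",
)

prompt = "Hello, world!"
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```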

OK, got it. Thanks!

Xidong changed discussion status to closed
