Do you really use flash attention?

#5 by GinnM - opened

I noticed that:

        attn = scaled_dot_product_attention(
            query=xq.transpose(1, 2),
            key=xk.transpose(1, 2),
            value=xv.transpose(1, 2),
            attn_mask=attention_mask.bool(),
            dropout_p=0,
        ).transpose(1, 2)

But when the attn_mask parameter is not None, scaled_dot_product_attention will not actually dispatch to the FlashAttention kernel; as far as I can tell it falls back to the memory-efficient or math implementation instead.
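
One way to check this on your own setup is to restrict SDPA to the flash backend only and see whether the call still succeeds. A minimal sketch (the shapes and mask below are made up for illustration; torch.backends.cuda.sdp_kernel is the pre-2.3 context manager, newer releases expose torch.nn.attention.sdpa_kernel instead):

    import torch
    import torch.nn.functional as F

    # Hypothetical shapes, just for illustration: (batch, heads, seq_len, head_dim)
    q = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.float16)
    k = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.float16)
    v = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.float16)
    attention_mask = torch.ones(1, 1, 128, 128, device="cuda", dtype=torch.bool)

    # Allow only the FlashAttention backend, so SDPA cannot silently fall
    # back to the memory-efficient or math implementations.
    with torch.backends.cuda.sdp_kernel(
        enable_flash=True, enable_math=False, enable_mem_efficient=False
    ):
        # With an explicit attn_mask this fails with "No available kernel"
        # on the PyTorch versions I have checked, which is what suggests the
        # flash kernel is not used when a mask is passed.
        out = F.scaled_dot_product_attention(q, k, v, attn_mask=attention_mask)

If the mask is purely causal, passing is_causal=True instead of a materialized attn_mask should keep the flash path available.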
