SmallQ FlashAttention for Ascend 910B

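An AscendC custom operator for speculative decoding, MTP (multi-token prediction) and decode-time attention. When the query length is very small (typically 1–8 tokens) and the KV sequence is long, it uses less UB (Unified Buffer, the AI Core's on-chip memory) and parallelizes better than a general-purpose FlashAttention.

  • Hardware: Ascend 910B3 (ascend910b) / Atlas 800I A2
  • CANN version: 8.5.0+
  • Data type: fp16 (fp32 accumulation internally)
  • Registered operator name: SmallqFlashAttention; aclnn API: aclnnSmallqFlashAttention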

Interface

Tensor      Shape                        Dtype
Q (input)   [numHeads, qLen, headDim]    fp16
K (input)   [numHeads, kvLen, headDim]   fp16
V (input)   [numHeads, kvLen, headDim]   fp16
O (output)  [numHeads, qLen, headDim]    fp16

O = softmax(Q · Kᵀ / √headDim) · V, computed independently for each head.
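When checking kernel output, it helps to have the formula in executable form. Below is a naive golden reference (the name `attention_reference` is hypothetical and not part of this repo). It uses nested Python lists instead of fp16 tensors and computes in double precision, so any comparison against the fp16 kernel output needs a tolerance.

```python
import math


def attention_reference(q, k, v):
    """Naive per-head reference for O = softmax(Q·Kᵀ/√headDim)·V.

    q: [numHeads][qLen][headDim], k and v: [numHeads][kvLen][headDim] (nested float lists).
    Returns O as [numHeads][qLen][headDim], computed in Python (double precision) floats.
    """
    out = []
    for qh, kh, vh in zip(q, k, v):          # heads are fully independent
        head_dim = len(qh[0])
        scale = 1.0 / math.sqrt(head_dim)
        oh = []
        for qrow in qh:
            scores = [scale * sum(a * b for a, b in zip(qrow, krow)) for krow in kh]
            m = max(scores)                  # subtract the row max for numerical stability
            p = [math.exp(s - m) for s in scores]
            denom = sum(p)
            oh.append([sum(p[j] * vh[j][d] for j in range(len(vh))) / denom
                       for d in range(head_dim)])
        out.append(oh)
    return out
```

With kvLen = 1 the single softmax weight is 1, so O reduces to V. With all scores equal (for example Q = 0), O is the mean of the V rows.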

Supported range

Parameter   Range
numHeads    any ≥ 1
qLen        any ≥ 1 (typically 1–8)
headDim     any ≥ 1 (no 16-alignment requirement)
kvLen       any ≥ 1

Tested in parallel on 7 × 910B3 NPUs: 256 test cases, 0 failures, covering:

  • Llama 3.1 8B / Qwen 3-8B (nh=32, hd=128)
  • Llama 3 70B / Qwen 3-32B (nh=64, hd=128)
  • Llama 3.1 405B (nh=128, hd=128)
  • Qwen 3-235B-A22B (nh=64, hd=128)
  • DeepSeek V3 (MLA K-side hd=192 / V-side hd=128)
  • Mistral / Gemma (nh=32, hd=128)
  • Boundary head_dim values: 64, 256
  • qLen ∈ {1, 2, 4, 8}
  • kvLen from 1 to 16384, including non-power-of-two values such as 1023/1025/1500/3000

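Algorithm highlights

  • K-outer tiling + online softmax: the KV sequence is processed in blocks of blockK positions (default 64); when UB space is tight, blockK is automatically halved, down to a minimum of 4
  • Q resident in UB: since qLen is small, the whole of Q is loaded into UB once and stays there
  • Double-buffered K prefetch: loading the next K block overlaps with computation on the current one
  • fp32 scalar QK dot products + fp32 scalar PV accumulation: avoids fp16 accumulation error and is correct for any headDim
  • Key to arbitrary headDim support: inside UB, rows are padded to hdPad = ceil(headDim/16)*16. On write-back to GM the output is compacted to a headDim stride, and the non-16-aligned tail is written UB→GM with DataCopyPad at byte granularity (the underlying hardware instruction copy_ubuf_to_gm_align_b16 uses a byte-level write enable, which avoids write races when several heads share a cache line)
  • Multi-core parallelism: one AI Core per head, blockDim = numHeads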
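The highlights above can be made concrete with a small pure-Python model of the per-head computation. It covers K-outer blocks of `block_k` KV positions, an online-softmax running max, denominator and accumulator per query row, accumulators padded to hdPad and compacted back to headDim on output, and blockK halving down to 4. It is only a sketch for reasoning about the algorithm and producing golden outputs, not the AscendC kernel:

- All function names are hypothetical.
- `fits` stands in for the real UB-budget check in the tiling code, which is not shown here.
- Double buffering and DataCopyPad are not modeled.
- Python floats are double precision rather than fp32.

```python
import math


def hd_pad(head_dim):
    """Row length inside UB: headDim rounded up to a multiple of 16 elements."""
    return (head_dim + 15) // 16 * 16


def choose_block_k(fits, default=64, floor=4):
    """Halve blockK from `default` until fits(block_k) is True, stopping at `floor`.

    `fits` is a hypothetical UB-budget predicate, not the repo's tiling logic.
    """
    block_k = default
    while block_k > floor and not fits(block_k):
        block_k //= 2
    return block_k


def smallq_attention_head(q, k, v, block_k):
    """One head (one AI Core on the device): K-outer tiling with online softmax.

    q: [qLen][headDim], k and v: [kvLen][headDim]. Accumulators are padded to hdPad
    (as rows are in UB) and compacted back to headDim when the result is produced.
    """
    head_dim = len(q[0])
    pad = hd_pad(head_dim)
    scale = 1.0 / math.sqrt(head_dim)
    q_len, kv_len = len(q), len(k)
    m = [-math.inf] * q_len                      # running max score per query row
    l = [0.0] * q_len                            # running softmax denominator per row
    acc = [[0.0] * pad for _ in range(q_len)]    # running unnormalized P·V, hdPad wide
    for start in range(0, kv_len, block_k):      # K-outer loop over KV blocks
        kb, vb = k[start:start + block_k], v[start:start + block_k]
        for i, qrow in enumerate(q):             # the same Q rows are reused for every block
            s = [scale * sum(a * b for a, b in zip(qrow, krow)) for krow in kb]
            m_new = max(m[i], max(s))
            alpha = math.exp(m[i] - m_new)       # rescales earlier blocks (0.0 on the first)
            p = [math.exp(x - m_new) for x in s]
            l[i] = l[i] * alpha + sum(p)
            for d in range(head_dim):            # padded tail columns stay zero
                acc[i][d] = acc[i][d] * alpha + sum(p[j] * vb[j][d] for j in range(len(vb)))
            m[i] = m_new
    # Normalize and compact from the hdPad stride back to the headDim stride.
    return [[acc[i][d] / l[i] for d in range(head_dim)] for i in range(q_len)]


def smallq_attention(q, k, v, block_k=64):
    """All heads; on the device each head maps to one AI Core (blockDim = numHeads)."""
    return [smallq_attention_head(qh, kh, vh, block_k) for qh, kh, vh in zip(q, k, v)]
```

The rescale factor alpha = exp(m_old - m_new) corrects the partial sums from earlier blocks. As a result, the output does not depend on block_k beyond rounding, so the tiling can shrink blockK under UB pressure without changing the result.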

File structure

smallq_flash_attention/      # Operator source (builds standalone)
├── README.md
├── CMakeLists.txt
├── op_kernel/
│   ├── smallq_flash_attention.cpp
│   └── smallq_flash_attention_impl.h
└── op_host/
    ├── smallq_flash_attention_def.cpp
    ├── smallq_flash_attention_proto.cpp
    ├── smallq_flash_attention_tiling.h
    └── smallq_flash_attention_tiling.cpp

tests/                       # Verification programs
├── CMakeLists.txt
├── test_aclnn.cpp           # aclnn invocation example (single case)
└── run_model_tests.py       # Multi-NPU parallel model-shape test driver

Build and deployment

Place the smallq_flash_attention/ directory under ops/ascendc/src/ in the cann-recipes-infer project, then:

cd cann-recipes-infer/ops/ascendc
bash build.sh -n "smallq_flash_attention" -c "ascend910b"

# Deploy (must be installed under the opp path; opp has higher priority than vendors)
bash output/CANN-custom_ops-none-linux.aarch64.run \
     --quiet --install-path=$ASCEND_HOME_PATH/opp/vendors/customize
# The installer creates a nested vendors/customize/ under the install path; move its contents up one level
yes | cp -rf $ASCEND_HOME_PATH/opp/vendors/customize/vendors/customize/. \
             $ASCEND_HOME_PATH/opp/vendors/customize/
rm -rf $ASCEND_HOME_PATH/opp/vendors/customize/vendors

Calling the aclnn API

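#include "acl/acl.h"
#include "aclnn_smallq_flash_attention.h"

// Assumes aclInit() and aclrtSetDevice() have already been called.
aclrtStream stream = nullptr;
aclrtCreateStream(&stream);

aclTensor *qTensor, *kTensor, *vTensor, *oTensor;
// ... create fp16 aclTensors: Q/O [numHeads, qLen, headDim], K/V [numHeads, kvLen, headDim]

// Phase 1: query the workspace size and build the executor.
uint64_t workspaceSize = 0;
aclOpExecutor* executor = nullptr;
aclnnStatus ret = aclnnSmallqFlashAttentionGetWorkspaceSize(
    qTensor, kTensor, vTensor, oTensor, &workspaceSize, &executor);
if (ret != ACL_SUCCESS) { /* handle the error */ }

// Phase 2: allocate the workspace (if any) and launch on the stream.
void* workspace = nullptr;
if (workspaceSize > 0) aclrtMalloc(&workspace, workspaceSize, ACL_MEM_MALLOC_HUGE_FIRST);
ret = aclnnSmallqFlashAttention(workspace, workspaceSize, executor, stream);
aclrtSynchronizeStream(stream);

// Release resources once the stream has finished.
if (workspace != nullptr) aclrtFree(workspace);
aclDestroyTensor(qTensor); aclDestroyTensor(kTensor);
aclDestroyTensor(vTensor); aclDestroyTensor(oTensor);
aclrtDestroyStream(stream);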

Testing

# Build the test program
cd tests
cmake -B build -DCANN_PATH=$ASCEND_HOME_PATH
cmake --build build -j

# Single case
NUM_HEADS=32 Q_LEN=1 HEAD_DIM=128 KV_LEN=4096 DEVICE_ID=1 IO_DIR=/tmp/io \
  ./build/test_aclnn

# Sweep real-model shapes in parallel across multiple NPUs (NPUs 1–7 by default)
python3 run_model_tests.py
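The repository's run_model_tests.py is not reproduced here. The following is a rough sketch of how such a driver could fan cases out across NPUs 1–7 using only the environment variables shown above. It assigns cases round-robin and runs one worker thread per device. Everything in it is an assumption:

- The case grid is illustrative, not the 256 validated cases.
- The per-device IO_DIR paths are hypothetical.
- It assumes test_aclnn exits with a non-zero status when a case fails.

```python
import itertools
import os
import subprocess
from concurrent.futures import ThreadPoolExecutor

# Illustrative case grid (hypothetical); not the repo's actual case list.
MODEL_SHAPES = [(32, 128), (64, 128), (128, 128)]   # (numHeads, headDim)
Q_LENS = [1, 2, 4, 8]
KV_LENS = [1, 1023, 4096]


def build_cases():
    """Cartesian product of model shapes, qLen and kvLen, keyed by test_aclnn's env vars."""
    return [dict(NUM_HEADS=nh, Q_LEN=ql, HEAD_DIM=hd, KV_LEN=kv)
            for (nh, hd), ql, kv in itertools.product(MODEL_SHAPES, Q_LENS, KV_LENS)]


def assign_devices(cases, devices=range(1, 8)):
    """Round-robin cases over NPU ids (default 1-7)."""
    devices = list(devices)
    return [(case, devices[i % len(devices)]) for i, case in enumerate(cases)]


def case_env(case, device_id, io_dir):
    """Environment for one test_aclnn run: inherited env plus the case parameters."""
    env = dict(os.environ)
    env.update({key: str(val) for key, val in case.items()})
    env["DEVICE_ID"] = str(device_id)
    env["IO_DIR"] = io_dir
    return env


def run_all(binary="./build/test_aclnn", devices=range(1, 8)):
    """Run every case; devices work in parallel, cases on one device run sequentially."""
    by_device = {}
    for case, dev in assign_devices(build_cases(), devices):
        by_device.setdefault(dev, []).append(case)

    def worker(dev):
        failures = []
        io_dir = f"/tmp/io_npu{dev}"          # hypothetical per-device scratch directory
        os.makedirs(io_dir, exist_ok=True)
        for case in by_device[dev]:
            proc = subprocess.run([binary], env=case_env(case, dev, io_dir))
            if proc.returncode != 0:          # assumes non-zero exit means the case failed
                failures.append((dev, case))
        return failures

    with ThreadPoolExecutor(max_workers=len(by_device)) as pool:
        return [f for fs in pool.map(worker, sorted(by_device)) for f in fs]
```

Called from tests/ after building, run_all() returns the (device, case) pairs whose process failed. Cases assigned to the same NPU run one after another, so two processes never contend for one device.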

License

Apache-2.0
