# SmallQ FlashAttention for Ascend 910B
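An AscendC custom operator for speculative decoding, MTP, and decode-time attention. When the query sequence is very short (typically 1–8 tokens) and the KV sequence is long, this operator uses less UB and parallelizes better than a general-purpose FlashAttention.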
- Hardware: Ascend 910B3 (ascend910b) / Atlas 800I A2
- CANN version: 8.5.0+
- Data type: fp16 (fp32 accumulation internally)
- Operator registration name: `SmallqFlashAttention`; aclnn API: `aclnnSmallqFlashAttention`
## Interface
| Tensor | Shape | Dtype |
|---|---|---|
| Q (input) | [numHeads, qLen, headDim] | fp16 |
| K (input) | [numHeads, kvLen, headDim] | fp16 |
| V (input) | [numHeads, kvLen, headDim] | fp16 |
| O (output) | [numHeads, qLen, headDim] | fp16 |
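$$O = \mathrm{softmax}\left(\frac{Q K^{\top}}{\sqrt{\mathrm{headDim}}}\right) V$$

computed independently for each head, with the softmax taken over the kvLen axis.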
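For host-side checking, a plain fp32 CPU reference of this formula can serve as a golden result. The following is a minimal sketch, assuming contiguous row-major float buffers in the layouts from the interface table; `ReferenceAttention` is a hypothetical name, not part of this project.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Naive fp32 reference for O = softmax(Q·Kᵀ/√headDim)·V, computed independently per head.
// Q/O are [numHeads, qLen, headDim], K/V are [numHeads, kvLen, headDim], contiguous row-major.
std::vector<float> ReferenceAttention(const std::vector<float>& q,
                                      const std::vector<float>& k,
                                      const std::vector<float>& v,
                                      std::size_t numHeads, std::size_t qLen,
                                      std::size_t kvLen, std::size_t headDim) {
    assert(q.size() == numHeads * qLen * headDim);
    assert(k.size() == numHeads * kvLen * headDim);
    assert(v.size() == numHeads * kvLen * headDim);
    const float scale = 1.0f / std::sqrt(static_cast<float>(headDim));
    std::vector<float> out(numHeads * qLen * headDim, 0.0f);
    std::vector<float> scores(kvLen);
    for (std::size_t h = 0; h < numHeads; ++h) {
        const float* qh = q.data() + h * qLen * headDim;
        const float* kh = k.data() + h * kvLen * headDim;
        const float* vh = v.data() + h * kvLen * headDim;
        float* oh = out.data() + h * qLen * headDim;
        for (std::size_t i = 0; i < qLen; ++i) {
            // scores[j] = (q_i · k_j) * scale; track the max for a numerically stable softmax
            float maxScore = -INFINITY;
            for (std::size_t j = 0; j < kvLen; ++j) {
                float dot = 0.0f;
                for (std::size_t d = 0; d < headDim; ++d)
                    dot += qh[i * headDim + d] * kh[j * headDim + d];
                scores[j] = dot * scale;
                maxScore = std::fmax(maxScore, scores[j]);
            }
            float sum = 0.0f;
            for (std::size_t j = 0; j < kvLen; ++j) {
                scores[j] = std::exp(scores[j] - maxScore);
                sum += scores[j];
            }
            // o_i = sum_j softmax_j * v_j
            for (std::size_t j = 0; j < kvLen; ++j) {
                const float p = scores[j] / sum;
                for (std::size_t d = 0; d < headDim; ++d)
                    oh[i * headDim + d] += p * vh[j * headDim + d];
            }
        }
    }
    return out;
}
```

Converting the fp16 inputs to float and running this gives a reference to compare against the operator's fp16 output; with kvLen = 1 the softmax weight is exactly 1, so O reduces to V.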
## Supported range
| Parameter | Range |
|---|---|
| numHeads | any ≥ 1 |
| qLen | any ≥ 1 (typically 1–8) |
| headDim | any ≥ 1 (no 16-alignment requirement) |
| kvLen | any ≥ 1 |
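Supporting any headDim and kvLen means the host-side tiling has to round headDim up for UB (hdPad) and shrink blockK from 64 toward 4 when UB is tight, as described in the algorithm section below. A minimal sketch of that selection logic follows. The per-block buffer list in `UbBytesFor` is an assumption made for illustration; the real accounting in `op_host/smallq_flash_attention_tiling.cpp` may differ, and `ChooseTiling`, `UbBytesFor` and `TilingSketch` are hypothetical names. The UB budget is passed in rather than hard-coded.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical tiling result: hdPad is headDim rounded up to a multiple of 16,
// blockK is the number of KV rows per block (0 means even the minimum does not fit).
struct TilingSketch {
    uint32_t hdPad;
    uint32_t blockK;
};

constexpr uint32_t kMaxBlockK = 64;  // default block size
constexpr uint32_t kMinBlockK = 4;   // lower bound when halving

// ASSUMED UB footprint for one block: resident fp16 Q, double-buffered fp16 K,
// one fp16 V block, fp32 score tile, fp32 output accumulator, fp32 running max/sum.
uint64_t UbBytesFor(uint32_t qLen, uint32_t hdPad, uint32_t blockK) {
    const uint64_t q = uint64_t(qLen) * hdPad * 2;
    const uint64_t k = 2ull * blockK * hdPad * 2;
    const uint64_t v = uint64_t(blockK) * hdPad * 2;
    const uint64_t scores = uint64_t(qLen) * blockK * 4;
    const uint64_t acc = uint64_t(qLen) * hdPad * 4;
    const uint64_t stats = uint64_t(qLen) * 2 * 4;
    return q + k + v + scores + acc + stats;
}

TilingSketch ChooseTiling(uint32_t qLen, uint32_t headDim, uint64_t ubBudgetBytes) {
    assert(qLen >= 1 && headDim >= 1);
    const uint32_t hdPad = (headDim + 15) / 16 * 16;  // ceil(headDim/16)*16
    uint32_t blockK = kMaxBlockK;
    // Halve blockK until the footprint fits, stopping at the lower bound.
    while (blockK > kMinBlockK && UbBytesFor(qLen, hdPad, blockK) > ubBudgetBytes)
        blockK /= 2;
    if (UbBytesFor(qLen, hdPad, blockK) > ubBudgetBytes) blockK = 0;
    return {hdPad, blockK};
}
```

For example, with qLen = 1, headDim = 128 and a 30000-byte budget, the 64-row block (50184 bytes under these assumptions) does not fit, so the sketch settles on blockK = 32 (25480 bytes).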
Validated in parallel on 7 × 910B3: 256 test cases, 0 failures, covering:
- Llama 3.1 8B / Qwen 3-8B (nh=32, hd=128)
- Llama 3 70B / Qwen 3-32B (nh=64, hd=128)
- Llama 3.1 405B (nh=128, hd=128)
- Qwen 3-235B-A22B (nh=64, hd=128)
- DeepSeek V3 (MLA K-side hd=192 / V-side hd=128)
- Mistral / Gemma (nh=32, hd=128)
- Boundary head_dim: 64, 256
- qLen ∈ {1, 2, 4, 8}
- kvLen from 1 to 16384, including non-power-of-two values such as 1023/1025/1500/3000
## Algorithm highlights
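- K-outer tiling + online softmax: blockK defaults to 64; when UB is tight it is automatically halved, down to a minimum of 4
- Q resident in UB: since qLen is small, the whole Q is loaded once and stays resident
- Double-buffered K prefetch: data movement overlaps with computation
- fp32 scalar QK dot products + fp32 scalar PV accumulation: avoids fp16 accumulation error and is correct for any headDim
- Key to arbitrary headDim support: inside UB, rows are aligned to `hdPad = ceil(headDim/16)*16`; when writing back to GM, the output is compacted to a headDim stride, and the non-16-aligned tail is written UB→GM at byte granularity with `DataCopyPad` (the underlying hardware instruction `copy_ubuf_to_gm_align_b16` uses byte-level write enables, which avoids write races when multiple heads share a cache line)
- Multi-core parallelism: one AI Core per head, blockDim = numHeads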
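The K-outer loop with online softmax can be illustrated with a host-side C++ sketch. It is not the AscendC kernel: it ignores UB, DataCopy, double buffering and fp16 I/O, and `OnlineSoftmaxAttention` is a hypothetical name. It shows the rescaling rule that lets each KV block be visited once: with running max m, running denominator l and un-normalized accumulator acc, a new block of scores s_j gives m' = max(m, max_j s_j), l' = l·exp(m − m') + Σ_j exp(s_j − m'), acc' = acc·exp(m − m') + Σ_j exp(s_j − m')·v_j, and finally O = acc / l.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// CPU sketch of K-outer tiling with online softmax for ONE head.
// Q: [qLen, headDim], K/V: [kvLen, headDim], fp32 row-major throughout.
std::vector<float> OnlineSoftmaxAttention(const std::vector<float>& q,
                                          const std::vector<float>& k,
                                          const std::vector<float>& v,
                                          std::size_t qLen, std::size_t kvLen,
                                          std::size_t headDim, std::size_t blockK) {
    assert(blockK >= 1 && kvLen >= 1);
    const float scale = 1.0f / std::sqrt(static_cast<float>(headDim));
    std::vector<float> acc(qLen * headDim, 0.0f);  // un-normalized sum of p_j * v_j
    std::vector<float> rowMax(qLen, -INFINITY);    // running max m
    std::vector<float> rowSum(qLen, 0.0f);         // running denominator l
    std::vector<float> s(blockK);

    // Outer loop over KV blocks; Q is reused for every block.
    for (std::size_t j0 = 0; j0 < kvLen; j0 += blockK) {
        const std::size_t n = std::min(blockK, kvLen - j0);  // tail block may be shorter
        for (std::size_t i = 0; i < qLen; ++i) {
            // 1) scores for this block and the block max
            float blockMax = -INFINITY;
            for (std::size_t j = 0; j < n; ++j) {
                float dot = 0.0f;
                for (std::size_t d = 0; d < headDim; ++d)
                    dot += q[i * headDim + d] * k[(j0 + j) * headDim + d];
                s[j] = dot * scale;
                blockMax = std::max(blockMax, s[j]);
            }
            // 2) rescale previous state to the new running max (exp(-inf) = 0 on the first block)
            const float newMax = std::max(rowMax[i], blockMax);
            const float correction = std::exp(rowMax[i] - newMax);
            rowSum[i] *= correction;
            for (std::size_t d = 0; d < headDim; ++d) acc[i * headDim + d] *= correction;
            // 3) add this block's contribution
            for (std::size_t j = 0; j < n; ++j) {
                const float p = std::exp(s[j] - newMax);
                rowSum[i] += p;
                for (std::size_t d = 0; d < headDim; ++d)
                    acc[i * headDim + d] += p * v[(j0 + j) * headDim + d];
            }
            rowMax[i] = newMax;
        }
    }
    // 4) final normalization: O = acc / l
    for (std::size_t i = 0; i < qLen; ++i)
        for (std::size_t d = 0; d < headDim; ++d) acc[i * headDim + d] /= rowSum[i];
    return acc;
}
```

Calling it with blockK ≥ kvLen degenerates to an ordinary single-pass softmax, so comparing that against a small blockK (for example 3 with kvLen = 10) is a quick way to check that the rescaling is correct.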
## File structure
```
smallq_flash_attention/ # operator source (builds standalone)
├── README.md
├── CMakeLists.txt
├── op_kernel/
│ ├── smallq_flash_attention.cpp
│ └── smallq_flash_attention_impl.h
└── op_host/
├── smallq_flash_attention_def.cpp
├── smallq_flash_attention_proto.cpp
├── smallq_flash_attention_tiling.h
└── smallq_flash_attention_tiling.cpp
tests/ # verification programs
├── CMakeLists.txt
├── test_aclnn.cpp # aclnn call example (single case)
└── run_model_tests.py # multi-card parallel model-shape test driver
```
## Build and deploy
Place the `smallq_flash_attention/` directory under `ops/ascendc/src/` in the cann-recipes-infer project, then run:
```bash
cd cann-recipes-infer/ops/ascendc
bash build.sh -n "smallq_flash_attention" -c "ascend910b"
# Deploy (must be installed under the opp path; opp takes precedence over vendors)
bash output/CANN-custom_ops-none-linux.aarch64.run \
    --quiet --install-path=$ASCEND_HOME_PATH/opp/vendors/customize
yes | cp -rf $ASCEND_HOME_PATH/opp/vendors/customize/vendors/customize/. \
    $ASCEND_HOME_PATH/opp/vendors/customize/
rm -rf $ASCEND_HOME_PATH/opp/vendors/customize/vendors
```
## aclnn API usage
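```cpp
#include "acl/acl.h"
#include "aclnn_smallq_flash_attention.h"

// Assumes aclInit, aclrtSetDevice and aclrtCreateStream(&stream) have already run.
aclrtStream stream;
aclTensor *qTensor, *kTensor, *vTensor, *oTensor;
// ... create fp16 aclTensors: Q/O are [numHeads, qLen, headDim], K/V are [numHeads, kvLen, headDim]

// Step 1: query the workspace size and build the executor.
uint64_t workspaceSize = 0;
aclOpExecutor* executor = nullptr;
auto ret = aclnnSmallqFlashAttentionGetWorkspaceSize(
    qTensor, kTensor, vTensor, oTensor, &workspaceSize, &executor);
if (ret != ACL_SUCCESS) { /* handle error */ }

// Step 2: allocate device workspace (if needed), launch, wait, then release the workspace.
void* workspace = nullptr;
if (workspaceSize > 0) aclrtMalloc(&workspace, workspaceSize, ACL_MEM_MALLOC_HUGE_FIRST);
ret = aclnnSmallqFlashAttention(workspace, workspaceSize, executor, stream);
aclrtSynchronizeStream(stream);
if (workspace != nullptr) aclrtFree(workspace);
```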
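The elided tensor-creation step typically needs contiguous strides and element counts for each shape. Below is a minimal helper sketch in plain C++, so it does not depend on CANN headers; `ContiguousStrides` and `ShapeSize` are hypothetical names. Passing the results to `aclCreateTensor` with `ACL_FLOAT16` and `ACL_FORMAT_ND`, as CANN's aclnn samples do, appears only in a comment and should be checked against your CANN version.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Contiguous (row-major) strides, in elements, for a shape such as
// [numHeads, qLen, headDim]. The last dimension has stride 1.
std::vector<int64_t> ContiguousStrides(const std::vector<int64_t>& shape) {
    assert(!shape.empty());
    std::vector<int64_t> strides(shape.size(), 1);
    for (std::size_t i = shape.size() - 1; i-- > 0;)
        strides[i] = shape[i + 1] * strides[i + 1];
    return strides;
}

// Total element count; an fp16 buffer of this shape needs ShapeSize(shape) * 2 bytes.
int64_t ShapeSize(const std::vector<int64_t>& shape) {
    int64_t n = 1;
    for (int64_t d : shape) n *= d;
    return n;
}

// The results would typically be passed on like this (not compiled here):
//   aclCreateTensor(shape.data(), shape.size(), ACL_FLOAT16, strides.data(), 0,
//                   ACL_FORMAT_ND, shape.data(), shape.size(), deviceAddr);
```

For Q with shape [32, 1, 128], the strides come out as [128, 128, 1].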
## Tests
```bash
# Build the test programs
cd tests
cmake -B build -DCANN_PATH=$ASCEND_HOME_PATH
cmake --build build -j
# Single case
NUM_HEADS=32 Q_LEN=1 HEAD_DIM=128 KV_LEN=4096 DEVICE_ID=1 IO_DIR=/tmp/io \
    ./build/test_aclnn
# Multi-card parallel sweep over real model shapes (NPUs 1–7 by default)
python3 run_model_tests.py
```
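The pass/fail criterion used by `test_aclnn` is not described here. For readers writing their own check of the fp16 output against an fp32 reference, a hedged sketch of an allclose-style comparison follows; `AllClose`, `CompareResult` and the default `atol`/`rtol` values are placeholders, not the project's thresholds.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Outcome of comparing a device output against a host reference.
struct CompareResult {
    bool pass;
    double maxAbsErr;
    std::size_t firstBad;  // index of the first failing element, or size() if none
};

// Element-wise check |out - ref| <= atol + rtol * |ref|, in the spirit of numpy.allclose.
// The default tolerances are placeholders chosen for illustration only.
CompareResult AllClose(const std::vector<float>& out, const std::vector<float>& ref,
                       double atol = 1e-3, double rtol = 1e-3) {
    assert(out.size() == ref.size());
    CompareResult r{true, 0.0, out.size()};
    for (std::size_t i = 0; i < out.size(); ++i) {
        const double err = std::fabs(double(out[i]) - double(ref[i]));
        if (err > r.maxAbsErr) r.maxAbsErr = err;
        // A NaN in either input makes err NaN, which fails the <= test below.
        if (!(err <= atol + rtol * std::fabs(double(ref[i]))) && r.pass) {
            r.pass = false;
            r.firstBad = i;
        }
    }
    return r;
}
```

Reporting `maxAbsErr` alongside pass/fail makes it easier to see how close a failing shape is to the threshold when sweeping many model shapes.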
## License
Apache-2.0