// llm_mutil_npu / tests/test_rope_fused.cpp
// Uploaded by: xianglarry
// Commit 4b9fefd: Initial C++ aclnn EAGER inference for Qwen3-235B-A22B MoE on Ascend 910 × 16 NPU
// test_rope_fused.cpp — compare aclnnApplyRotaryPosEmbV2 against our manual 8-op HF RoPE.
// If rotaryMode="half" matches HF, apply_rope_manual can be replaced by a single op,
// cutting the RoPE phase's per-layer op count from 8 to 1.
#include "acl_common.h"
#include "acl_runtime.h"
#include "aclnn_ops.h"
#include "rope.h"
#include "engine.h" // for fill_cos_sin_hf + RopeCache
#include <aclnnop/aclnn_apply_rotary_pos_emb_v2.h>
#include <cmath>
#include <cstdio>
#include <cstring>
#include <vector>
// Widen raw bf16 bits to an fp32 value. bf16 is the upper 16 bits of an
// IEEE-754 binary32, so the bits are placed in the high half and the low
// mantissa half is zero-filled before reinterpreting.
static float bf16_to_float(uint16_t x) {
    const uint32_t widened = static_cast<uint32_t>(x) << 16;
    float out;
    std::memcpy(&out, &widened, sizeof out);
    return out;
}
// Convert an fp32 value to raw bf16 bits using round-to-nearest-even.
// NaNs bypass the rounding add. Otherwise the carry could push a NaN payload
// into the exponent (e.g. 0x7F800001 -> Inf) or wrap 0xFFFFFFFF around to +0.
// A NaN input always yields a quiet bf16 NaN with the sign preserved.
static uint16_t f_to_bf16(float f) {
    uint32_t u;
    std::memcpy(&u, &f, sizeof(u));
    if ((u & 0x7FFFFFFFu) > 0x7F800000u) {
        // Keep sign and top mantissa bits; force the quiet bit so the result stays NaN.
        return static_cast<uint16_t>((u >> 16) | 0x0040u);
    }
    // Add 0x7FFF, plus 1 when the kept LSB is odd: this rounds halfway cases to even.
    const uint32_t rounding_bias = 0x7FFFu + ((u >> 16) & 1u);
    return static_cast<uint16_t>((u + rounding_bias) >> 16);
}
int main() {
AclRuntime rt;
rt.init(0);
// Test shape: 1 batch, 5 seq, 4 heads, head_dim=128 (Qwen3-like)
const int64_t B = 1, S = 5, Hq = 4, Hkv = 4, Dh = 128;
const float theta = 5e6f; // Qwen3 theta
// Random q, k (deterministic from seed)
std::vector<uint16_t> h_q(B * S * Hq * Dh), h_k(B * S * Hkv * Dh);
uint32_t seed = 42;
auto rnd = [&seed]() {
seed = seed * 1103515245 + 12345;
return f_to_bf16(((seed >> 16) / 32768.0f - 1.0f) * 0.1f);
};
for (auto& x : h_q) x = rnd();
for (auto& x : h_k) x = rnd();
// cos/sin cache (positions 0..S-1)
std::vector<uint16_t> cos_h, sin_h;
fill_cos_sin_hf(cos_h, sin_h, 0, S, Dh, theta);
DeviceBuffer q1(h_q.size() * 2), k1(h_k.size() * 2);
DeviceBuffer q2(h_q.size() * 2), k2(h_k.size() * 2);
DeviceBuffer cos_dev(cos_h.size() * 2), sin_dev(sin_h.size() * 2);
DeviceBuffer scratch(B * S * Hq * Dh * 2);
ACL_CHECK(aclrtMemcpy(q1.get(), h_q.size()*2, h_q.data(), h_q.size()*2, ACL_MEMCPY_HOST_TO_DEVICE));
ACL_CHECK(aclrtMemcpy(q2.get(), h_q.size()*2, h_q.data(), h_q.size()*2, ACL_MEMCPY_HOST_TO_DEVICE));
ACL_CHECK(aclrtMemcpy(k1.get(), h_k.size()*2, h_k.data(), h_k.size()*2, ACL_MEMCPY_HOST_TO_DEVICE));
ACL_CHECK(aclrtMemcpy(k2.get(), h_k.size()*2, h_k.data(), h_k.size()*2, ACL_MEMCPY_HOST_TO_DEVICE));
ACL_CHECK(aclrtMemcpy(cos_dev.get(), cos_h.size()*2, cos_h.data(), cos_h.size()*2, ACL_MEMCPY_HOST_TO_DEVICE));
ACL_CHECK(aclrtMemcpy(sin_dev.get(), sin_h.size()*2, sin_h.data(), sin_h.size()*2, ACL_MEMCPY_HOST_TO_DEVICE));
// --- Path 1: our manual HF RoPE ---
apply_rope_manual(rt.stream(), q1.get(), B, S, Hq, Dh, k1.get(), Hkv,
cos_dev.get(), sin_dev.get(), scratch.get());
rt.sync();
std::vector<uint16_t> q1_out(h_q.size()), k1_out(h_k.size());
ACL_CHECK(aclrtMemcpy(q1_out.data(), h_q.size()*2, q1.get(), h_q.size()*2, ACL_MEMCPY_DEVICE_TO_HOST));
ACL_CHECK(aclrtMemcpy(k1_out.data(), h_k.size()*2, k1.get(), h_k.size()*2, ACL_MEMCPY_DEVICE_TO_HOST));
// --- Path 2: aclnnApplyRotaryPosEmbV2 with rotaryMode="half" ---
// Layout: see docs. Common: 0=BSND, 1=SBND, 2=BNSD. q/k shape [B, S, N, Dh].
// cos/sin shape: typically [1, S, 1, Dh] or [S, Dh].
// Try multiple combinations until one works
struct Try { int64_t layout; const char* mode; std::vector<int64_t> qshape; std::vector<int64_t> cshape; };
std::vector<Try> tries = {
{0, "half", {B, S, Hq, Dh}, {1, S, 1, Dh}},
{1, "half", {B, S, Hq, Dh}, {1, S, 1, Dh}},
{2, "half", {B, Hq, S, Dh}, {1, 1, S, Dh}},
{0, "half", {B, S, Hq, Dh}, {S, Dh}},
{0, "interleaved", {B, S, Hq, Dh}, {1, S, 1, Dh}},
{0, "half", {S, Hq, Dh}, {S, 1, Dh}},
};
uint64_t ws = 0; aclOpExecutor* exec = nullptr;
aclnnStatus s1 = -1;
Try chosen{};
for (auto& t : tries) {
auto t_q = make_contig_tensor(q2.get(), ACL_BF16, t.qshape);
std::vector<int64_t> kshape = t.qshape; if (kshape.size() >= 3) kshape[kshape.size()-2] = Hkv;
auto t_k = make_contig_tensor(k2.get(), ACL_BF16, kshape);
auto t_cos = make_contig_tensor(cos_dev.get(), ACL_BF16, t.cshape);
auto t_sin = make_contig_tensor(sin_dev.get(), ACL_BF16, t.cshape);
char buf[32]; strncpy(buf, t.mode, sizeof(buf));
s1 = aclnnApplyRotaryPosEmbV2GetWorkspaceSize(t_q.get(), t_k.get(), t_cos.get(), t_sin.get(),
t.layout, buf, &ws, &exec);
printf("[ropev2] layout=%ld mode=%-12s qshape=%zu cshape=%zu → status=%d\n",
t.layout, t.mode, t.qshape.size(), t.cshape.size(), (int)s1);
if (s1 == 0) { chosen = t; break; }
}
if (s1 != 0) { fprintf(stderr, "All combos failed\n"); return 1; }
printf("→ winning: layout=%ld mode=%s\n", chosen.layout, chosen.mode);
DeviceBuffer wb; if (ws > 0) wb.alloc(ws);
s1 = aclnnApplyRotaryPosEmbV2(wb.get(), ws, exec, rt.stream());
printf("[ropev2] exec: status=%d\n", (int)s1);
if (s1 != 0) return 1;
rt.sync();
std::vector<uint16_t> q2_out(h_q.size()), k2_out(h_k.size());
ACL_CHECK(aclrtMemcpy(q2_out.data(), h_q.size()*2, q2.get(), h_q.size()*2, ACL_MEMCPY_DEVICE_TO_HOST));
ACL_CHECK(aclrtMemcpy(k2_out.data(), h_k.size()*2, k2.get(), h_k.size()*2, ACL_MEMCPY_DEVICE_TO_HOST));
// Compare
double q_l2d = 0, q_l2r = 0, q_max = 0;
for (size_t i = 0; i < h_q.size(); i++) {
float a = bf16_to_float(q1_out[i]), b = bf16_to_float(q2_out[i]);
q_l2d += (a-b)*(a-b); q_l2r += a*a;
if (std::abs(a-b) > q_max) q_max = std::abs(a-b);
}
double q_rel = std::sqrt(q_l2d) / (std::sqrt(q_l2r) + 1e-10);
double k_l2d = 0, k_l2r = 0, k_max = 0;
for (size_t i = 0; i < h_k.size(); i++) {
float a = bf16_to_float(k1_out[i]), b = bf16_to_float(k2_out[i]);
k_l2d += (a-b)*(a-b); k_l2r += a*a;
if (std::abs(a-b) > k_max) k_max = std::abs(a-b);
}
double k_rel = std::sqrt(k_l2d) / (std::sqrt(k_l2r) + 1e-10);
printf("\nManual-HF vs aclnnApplyRotaryPosEmbV2(layout=0, mode=half):\n");
printf(" Q: rel=%.4e max=%.4f\n", q_rel, q_max);
printf(" K: rel=%.4e max=%.4f\n", k_rel, k_max);
printf(" Q[0,:4] manual: %.5f %.5f %.5f %.5f\n",
bf16_to_float(q1_out[0]), bf16_to_float(q1_out[1]),
bf16_to_float(q1_out[2]), bf16_to_float(q1_out[3]));
printf(" Q[0,:4] ropev2: %.5f %.5f %.5f %.5f\n",
bf16_to_float(q2_out[0]), bf16_to_float(q2_out[1]),
bf16_to_float(q2_out[2]), bf16_to_float(q2_out[3]));
bool pass = q_rel < 1e-2 && k_rel < 1e-2;
printf("\n%s\n", pass ? "=== RoPE V2 matches manual HF ===" : "=== MISMATCH — need different mode/layout ===");
return pass ? 0 : 1;
}