| |
| |
| |
| #include "acl_common.h" |
| #include "acl_runtime.h" |
| #include "aclnn_ops.h" |
| #include "rope.h" |
| #include "engine.h" |
|
|
| #include <aclnnop/aclnn_apply_rotary_pos_emb_v2.h> |
|
|
| #include <cmath> |
| #include <cstdio> |
| #include <cstring> |
| #include <vector> |
|
|
// Decode a bfloat16 bit pattern into a float. A bf16 value is exactly the upper
// 16 bits of an IEEE-754 binary32, so widening is a shift plus a bit copy.
static float bf16_to_float(uint16_t x) {
    const uint32_t widened = static_cast<uint32_t>(x) << 16;
    float value;
    std::memcpy(&value, &widened, sizeof value);
    return value;
}
// Encode a float as a bfloat16 bit pattern using round-to-nearest-even.
// NaN inputs are returned as a quiet NaN (sign preserved). Without that special case,
// the rounding add would carry a NaN's low mantissa bits into the exponent (yielding
// Inf) or, for negative NaNs near 0xFFFFFFFF, wrap the 32-bit sum and yield +0.0.
static uint16_t f_to_bf16(float f) {
    uint32_t u;
    std::memcpy(&u, &f, sizeof u);
    if ((u & 0x7FFFFFFFu) > 0x7F800000u) {
        // NaN: truncate and force the quiet bit so the bf16 mantissa stays non-zero.
        return static_cast<uint16_t>((u >> 16) | 0x0040u);
    }
    const uint32_t lsb = (u >> 16) & 1u;  // ties go to the even bf16 value
    return static_cast<uint16_t>((u + 0x7FFFu + lsb) >> 16);
}
|
|
| int main() { |
| AclRuntime rt; |
| rt.init(0); |
|
|
| |
| const int64_t B = 1, S = 5, Hq = 4, Hkv = 4, Dh = 128; |
| const float theta = 5e6f; |
|
|
| |
| std::vector<uint16_t> h_q(B * S * Hq * Dh), h_k(B * S * Hkv * Dh); |
| uint32_t seed = 42; |
| auto rnd = [&seed]() { |
| seed = seed * 1103515245 + 12345; |
| return f_to_bf16(((seed >> 16) / 32768.0f - 1.0f) * 0.1f); |
| }; |
| for (auto& x : h_q) x = rnd(); |
| for (auto& x : h_k) x = rnd(); |
|
|
| |
| std::vector<uint16_t> cos_h, sin_h; |
| fill_cos_sin_hf(cos_h, sin_h, 0, S, Dh, theta); |
|
|
| DeviceBuffer q1(h_q.size() * 2), k1(h_k.size() * 2); |
| DeviceBuffer q2(h_q.size() * 2), k2(h_k.size() * 2); |
| DeviceBuffer cos_dev(cos_h.size() * 2), sin_dev(sin_h.size() * 2); |
| DeviceBuffer scratch(B * S * Hq * Dh * 2); |
|
|
| ACL_CHECK(aclrtMemcpy(q1.get(), h_q.size()*2, h_q.data(), h_q.size()*2, ACL_MEMCPY_HOST_TO_DEVICE)); |
| ACL_CHECK(aclrtMemcpy(q2.get(), h_q.size()*2, h_q.data(), h_q.size()*2, ACL_MEMCPY_HOST_TO_DEVICE)); |
| ACL_CHECK(aclrtMemcpy(k1.get(), h_k.size()*2, h_k.data(), h_k.size()*2, ACL_MEMCPY_HOST_TO_DEVICE)); |
| ACL_CHECK(aclrtMemcpy(k2.get(), h_k.size()*2, h_k.data(), h_k.size()*2, ACL_MEMCPY_HOST_TO_DEVICE)); |
| ACL_CHECK(aclrtMemcpy(cos_dev.get(), cos_h.size()*2, cos_h.data(), cos_h.size()*2, ACL_MEMCPY_HOST_TO_DEVICE)); |
| ACL_CHECK(aclrtMemcpy(sin_dev.get(), sin_h.size()*2, sin_h.data(), sin_h.size()*2, ACL_MEMCPY_HOST_TO_DEVICE)); |
|
|
| |
| apply_rope_manual(rt.stream(), q1.get(), B, S, Hq, Dh, k1.get(), Hkv, |
| cos_dev.get(), sin_dev.get(), scratch.get()); |
| rt.sync(); |
|
|
| std::vector<uint16_t> q1_out(h_q.size()), k1_out(h_k.size()); |
| ACL_CHECK(aclrtMemcpy(q1_out.data(), h_q.size()*2, q1.get(), h_q.size()*2, ACL_MEMCPY_DEVICE_TO_HOST)); |
| ACL_CHECK(aclrtMemcpy(k1_out.data(), h_k.size()*2, k1.get(), h_k.size()*2, ACL_MEMCPY_DEVICE_TO_HOST)); |
|
|
| |
| |
| |
|
|
| |
| struct Try { int64_t layout; const char* mode; std::vector<int64_t> qshape; std::vector<int64_t> cshape; }; |
| std::vector<Try> tries = { |
| {0, "half", {B, S, Hq, Dh}, {1, S, 1, Dh}}, |
| {1, "half", {B, S, Hq, Dh}, {1, S, 1, Dh}}, |
| {2, "half", {B, Hq, S, Dh}, {1, 1, S, Dh}}, |
| {0, "half", {B, S, Hq, Dh}, {S, Dh}}, |
| {0, "interleaved", {B, S, Hq, Dh}, {1, S, 1, Dh}}, |
| {0, "half", {S, Hq, Dh}, {S, 1, Dh}}, |
| }; |
|
|
| uint64_t ws = 0; aclOpExecutor* exec = nullptr; |
| aclnnStatus s1 = -1; |
| Try chosen{}; |
| for (auto& t : tries) { |
| auto t_q = make_contig_tensor(q2.get(), ACL_BF16, t.qshape); |
| std::vector<int64_t> kshape = t.qshape; if (kshape.size() >= 3) kshape[kshape.size()-2] = Hkv; |
| auto t_k = make_contig_tensor(k2.get(), ACL_BF16, kshape); |
| auto t_cos = make_contig_tensor(cos_dev.get(), ACL_BF16, t.cshape); |
| auto t_sin = make_contig_tensor(sin_dev.get(), ACL_BF16, t.cshape); |
| char buf[32]; strncpy(buf, t.mode, sizeof(buf)); |
| s1 = aclnnApplyRotaryPosEmbV2GetWorkspaceSize(t_q.get(), t_k.get(), t_cos.get(), t_sin.get(), |
| t.layout, buf, &ws, &exec); |
| printf("[ropev2] layout=%ld mode=%-12s qshape=%zu cshape=%zu → status=%d\n", |
| t.layout, t.mode, t.qshape.size(), t.cshape.size(), (int)s1); |
| if (s1 == 0) { chosen = t; break; } |
| } |
| if (s1 != 0) { fprintf(stderr, "All combos failed\n"); return 1; } |
| printf("→ winning: layout=%ld mode=%s\n", chosen.layout, chosen.mode); |
| DeviceBuffer wb; if (ws > 0) wb.alloc(ws); |
| s1 = aclnnApplyRotaryPosEmbV2(wb.get(), ws, exec, rt.stream()); |
| printf("[ropev2] exec: status=%d\n", (int)s1); |
| if (s1 != 0) return 1; |
| rt.sync(); |
|
|
| std::vector<uint16_t> q2_out(h_q.size()), k2_out(h_k.size()); |
| ACL_CHECK(aclrtMemcpy(q2_out.data(), h_q.size()*2, q2.get(), h_q.size()*2, ACL_MEMCPY_DEVICE_TO_HOST)); |
| ACL_CHECK(aclrtMemcpy(k2_out.data(), h_k.size()*2, k2.get(), h_k.size()*2, ACL_MEMCPY_DEVICE_TO_HOST)); |
|
|
| |
| double q_l2d = 0, q_l2r = 0, q_max = 0; |
| for (size_t i = 0; i < h_q.size(); i++) { |
| float a = bf16_to_float(q1_out[i]), b = bf16_to_float(q2_out[i]); |
| q_l2d += (a-b)*(a-b); q_l2r += a*a; |
| if (std::abs(a-b) > q_max) q_max = std::abs(a-b); |
| } |
| double q_rel = std::sqrt(q_l2d) / (std::sqrt(q_l2r) + 1e-10); |
|
|
| double k_l2d = 0, k_l2r = 0, k_max = 0; |
| for (size_t i = 0; i < h_k.size(); i++) { |
| float a = bf16_to_float(k1_out[i]), b = bf16_to_float(k2_out[i]); |
| k_l2d += (a-b)*(a-b); k_l2r += a*a; |
| if (std::abs(a-b) > k_max) k_max = std::abs(a-b); |
| } |
| double k_rel = std::sqrt(k_l2d) / (std::sqrt(k_l2r) + 1e-10); |
|
|
| printf("\nManual-HF vs aclnnApplyRotaryPosEmbV2(layout=0, mode=half):\n"); |
| printf(" Q: rel=%.4e max=%.4f\n", q_rel, q_max); |
| printf(" K: rel=%.4e max=%.4f\n", k_rel, k_max); |
| printf(" Q[0,:4] manual: %.5f %.5f %.5f %.5f\n", |
| bf16_to_float(q1_out[0]), bf16_to_float(q1_out[1]), |
| bf16_to_float(q1_out[2]), bf16_to_float(q1_out[3])); |
| printf(" Q[0,:4] ropev2: %.5f %.5f %.5f %.5f\n", |
| bf16_to_float(q2_out[0]), bf16_to_float(q2_out[1]), |
| bf16_to_float(q2_out[2]), bf16_to_float(q2_out[3])); |
|
|
| bool pass = q_rel < 1e-2 && k_rel < 1e-2; |
| printf("\n%s\n", pass ? "=== RoPE V2 matches manual HF ===" : "=== MISMATCH — need different mode/layout ==="); |
| return pass ? 0 : 1; |
| } |
|
|