| |
#include "acl_common.h"
#include "acl_runtime.h"

#include <aclnnop/aclnn_apply_rotary_pos_emb.h>
#include <aclnnop/aclnn_apply_rotary_pos_emb_v2.h>

#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <fstream>
#include <stdexcept>
#include <string>
#include <vector>
|
|
/// Converts a raw bfloat16 bit pattern to a 32-bit float.
///
/// The 16 input bits become the upper 16 bits of the float's bit pattern and
/// the lower 16 bits are zero, so the conversion is exact.
///
/// @param x  bfloat16 value as its 16-bit storage representation.
/// @return   The float whose bit pattern is `x << 16`.
static float bf16_to_float(uint16_t x) {
    const uint32_t bits = static_cast<uint32_t>(x) << 16;
    float value;
    // memcpy is the well-defined way to reinterpret the bits as a float.
    std::memcpy(&value, &bits, sizeof(value));
    return value;
}
|
|
/// Reads the whole file at `p` into a byte vector.
///
/// @param p  Path of the file to load; it is opened in binary mode.
/// @return   The file contents (an empty vector for an empty file).
/// @throws std::runtime_error if the file cannot be opened, its size cannot be
///         determined, or the contents cannot be read completely.
static std::vector<uint8_t> read_file(const std::string& p) {
    // Open positioned at the end so tellg() yields the file size.
    std::ifstream in(p, std::ios::binary | std::ios::ate);
    if (!in) {
        throw std::runtime_error("read_file: cannot open '" + p + "'");
    }

    // tellg() reports -1 on failure; converting that straight to size_t would
    // request an enormous allocation, so check it first.
    const std::streamoff end = in.tellg();
    if (end < 0) {
        throw std::runtime_error("read_file: cannot determine size of '" + p + "'");
    }

    std::vector<uint8_t> bytes(static_cast<size_t>(end));
    in.seekg(0, std::ios::beg);
    if (!bytes.empty() &&
        !in.read(reinterpret_cast<char*>(bytes.data()),
                 static_cast<std::streamsize>(bytes.size()))) {
        throw std::runtime_error("read_file: short read on '" + p + "'");
    }
    return bytes;
}
|
|
| int main() { |
| const std::string data = "tests/attn_data"; |
|
|
| AclRuntime rt; |
| rt.init(0); |
|
|
| |
| auto qn_h = read_file(data + "/q_normed.bin"); |
| auto kn_h = read_file(data + "/k_normed.bin"); |
| auto cos_h = read_file(data + "/cos.bin"); |
| auto sin_h = read_file(data + "/sin.bin"); |
| auto qr_h = read_file(data + "/q_roped.bin"); |
| auto kr_h = read_file(data + "/k_roped.bin"); |
|
|
| |
| const int64_t S = 5, Hq = 64, Hkv = 4, Dh = 128; |
|
|
| DeviceBuffer q_d(qn_h.size()), k_d(kn_h.size()), cos_d(cos_h.size()), sin_d(sin_h.size()); |
| ACL_CHECK(aclrtMemcpy(q_d.get(), qn_h.size(), qn_h.data(), qn_h.size(), ACL_MEMCPY_HOST_TO_DEVICE)); |
| ACL_CHECK(aclrtMemcpy(k_d.get(), kn_h.size(), kn_h.data(), kn_h.size(), ACL_MEMCPY_HOST_TO_DEVICE)); |
| ACL_CHECK(aclrtMemcpy(cos_d.get(), cos_h.size(), cos_h.data(), cos_h.size(), ACL_MEMCPY_HOST_TO_DEVICE)); |
| ACL_CHECK(aclrtMemcpy(sin_d.get(), sin_h.size(), sin_h.data(), sin_h.size(), ACL_MEMCPY_HOST_TO_DEVICE)); |
|
|
| |
| |
| |
| auto t_q = make_contig_tensor(q_d.get(), ACL_BF16, {1, S, Hq, Dh}); |
| auto t_k = make_contig_tensor(k_d.get(), ACL_BF16, {1, S, Hkv, Dh}); |
| auto t_cos = make_contig_tensor(cos_d.get(), ACL_BF16, {1, S, 1, Dh}); |
| auto t_sin = make_contig_tensor(sin_d.get(), ACL_BF16, {1, S, 1, Dh}); |
|
|
| int layout = 1; |
| const char* env_layout = std::getenv("LAYOUT"); |
| if (env_layout) layout = std::atoi(env_layout); |
| std::string mode = "half"; |
| const char* env_mode = std::getenv("MODE"); |
| if (env_mode) mode = env_mode; |
| bool use_v2 = (std::getenv("V2") != nullptr); |
| printf("layout=%d mode=%s v2=%d\n", layout, mode.c_str(), (int)use_v2); |
|
|
| uint64_t ws = 0; |
| aclOpExecutor* exec = nullptr; |
| if (use_v2) { |
| |
| aclnnStatus st = aclnnApplyRotaryPosEmbV2GetWorkspaceSize( |
| t_q.get(), t_k.get(), t_cos.get(), t_sin.get(), |
| layout, (char*)mode.c_str(), |
| &ws, &exec); |
| if (st != 0) { |
| fprintf(stderr, "V2 GetWS status=%d %s\n", (int)st, aclGetRecentErrMsg()); |
| return 1; |
| } |
| DeviceBuffer ws_buf; |
| if (ws > 0) ws_buf.alloc(ws); |
| ACLNN_CHECK(aclnnApplyRotaryPosEmbV2(ws_buf.get(), ws, exec, rt.stream())); |
| } else { |
| aclnnStatus st = aclnnApplyRotaryPosEmbGetWorkspaceSize( |
| t_q.get(), t_k.get(), t_cos.get(), t_sin.get(), |
| layout, |
| &ws, &exec); |
| if (st != 0) { |
| fprintf(stderr, "V1 GetWS status=%d %s\n", (int)st, aclGetRecentErrMsg()); |
| return 1; |
| } |
| DeviceBuffer ws_buf; |
| if (ws > 0) ws_buf.alloc(ws); |
| ACLNN_CHECK(aclnnApplyRotaryPosEmb(ws_buf.get(), ws, exec, rt.stream())); |
| } |
| rt.sync(); |
|
|
| |
| std::vector<uint16_t> q_out(S * Hq * Dh); |
| ACL_CHECK(aclrtMemcpy(q_out.data(), qn_h.size(), q_d.get(), qn_h.size(), ACL_MEMCPY_DEVICE_TO_HOST)); |
| auto* q_ref = (const uint16_t*)qr_h.data(); |
|
|
| double l2d = 0, l2r = 0, maxd = 0; |
| for (int i = 0; i < (int)(S*Hq*Dh); i++) { |
| float a = bf16_to_float(q_out[i]), b = bf16_to_float(q_ref[i]); |
| l2d += (a-b)*(a-b); l2r += b*b; |
| if (std::abs(a-b) > maxd) maxd = std::abs(a-b); |
| } |
| double rel = std::sqrt(l2d) / (std::sqrt(l2r) + 1e-10); |
| printf("Q rope compare: rel=%.4e max_abs=%.4f\n", rel, maxd); |
| printf(" cxx q[0,0,:4]: "); for (int i = 0; i < 4; i++) printf("%.4f ", bf16_to_float(q_out[i])); |
| printf("\n ref q[0,0,:4]: "); for (int i = 0; i < 4; i++) printf("%.4f ", bf16_to_float(q_ref[i])); printf("\n"); |
|
|
| bool ok = rel < 1e-2; |
| printf("\n%s\n", ok ? "=== test_rope PASS ===" : "=== test_rope FAIL ==="); |
| return ok ? 0 : 1; |
| } |
|
|