Initial C++ aclnn EAGER inference for Qwen3-235B-A22B MoE on Ascend 910 × 16 NPU
Pure-C++ inference runtime built directly on aclnn single-op API (no graph
compilation, no PyTorch, no ggml). Targets Qwen3-235B-A22B-Instruct-2507 BF16
with TP=16 HCCL tensor parallelism.
Quality-preserving throughput:
- Untuned baseline: 12 t/s
- Recommended (HCCL env + Fused RoPE + small ops): ~27 t/s (all prompts)
- PLD with degeneration guard: 29-45 t/s (structured long-form text)
Key components:
- 12 headers + 6 sources implementing attention/MoE/Runner/HCCL/RoPE
- Fused RoPE via aclnnApplyRotaryPosEmbV2 (layout=1, "half")
- PLD (Prompt Lookup Decoding) with degeneration guard:
low-distinct + tail-echo heuristics block loop-amplifying drafts
- bench_pld_safe.sh classifies each run as OK / LOOP_N / LOW_DIVERSITY
and separates TG stats accordingly (honest performance reporting)
- 19 unit / integration tests + end-to-end smoke test
HCCL environment (applied by tp_launch.sh):
HCCL_OP_EXPANSION_MODE=AIV + HCCL_OP_BASE_FFTS_MODE_ENABLE=1 +
TASK_QUEUE_ENABLE=2 contributes +89% TG vs default ring-only.
Known limitations:
- Does not exceed cann-recipes-infer GE graph baseline of ~54 t/s
- PLD on factual/code prompts is unreliable (disable or use bench_pld_safe.sh)
- Requires Ascend 910 initial-gen × 16 NPU and CANN 8.5.1
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
- .gitignore +48 -0
- CMakeLists.txt +110 -0
- LICENSE +176 -0
- README.md +338 -0
- external/json.hpp +0 -0
- include/acl_common.h +106 -0
- include/acl_runtime.h +41 -0
- include/aclnn_ops.h +345 -0
- include/device_weights.h +82 -0
- include/engine.h +354 -0
- include/hccl_comm.h +106 -0
- include/model_config.h +52 -0
- include/rope.h +94 -0
- include/runner.h +128 -0
- include/safetensors_loader.h +78 -0
- include/tokenizer.h +38 -0
- include/workspace_pool.h +84 -0
- scripts/bench_hccl.sh +56 -0
- scripts/bench_hccl_adv.sh +56 -0
- scripts/bench_hccl_adv2.sh +56 -0
- scripts/bench_pld.sh +69 -0
- scripts/bench_pld_k.sh +41 -0
- scripts/bench_pld_safe.sh +154 -0
- scripts/bench_tg.sh +40 -0
- scripts/export_vocab.py +85 -0
- scripts/gen_attention_reference.py +179 -0
- scripts/gen_gmm_reference.py +89 -0
- scripts/gen_mm_reference.py +23 -0
- scripts/gen_moe_reference.py +115 -0
- scripts/gen_rms_norm_reference.py +39 -0
- scripts/regen_rope_reference.py +62 -0
- scripts/tp_launch.sh +58 -0
- src/device_weights.cpp +221 -0
- src/main_cli.cpp +816 -0
- src/model_config.cpp +115 -0
- src/runner.cpp +428 -0
- src/safetensors_loader.cpp +172 -0
- src/tokenizer.cpp +176 -0
- tests/hello_acl.cpp +62 -0
- tests/test_attention_decode.cpp +319 -0
- tests/test_attention_layer.cpp +219 -0
- tests/test_batch_correctness.cpp +98 -0
- tests/test_batch_decode.cpp +85 -0
- tests/test_chat_flow.sh +72 -0
- tests/test_engine_smoke.cpp +8 -0
- tests/test_layer_forward.cpp +192 -0
- tests/test_linear_hf.cpp +73 -0
- tests/test_model_config.cpp +106 -0
- tests/test_moe_layer.cpp +676 -0
- tests/test_op_support.cpp +190 -0
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Build artifacts
|
| 2 |
+
/build/
|
| 3 |
+
*.o
|
| 4 |
+
*.obj
|
| 5 |
+
*.a
|
| 6 |
+
*.so
|
| 7 |
+
*.exe
|
| 8 |
+
|
| 9 |
+
# CMake
|
| 10 |
+
CMakeCache.txt
|
| 11 |
+
CMakeFiles/
|
| 12 |
+
cmake_install.cmake
|
| 13 |
+
Makefile
|
| 14 |
+
compile_commands.json
|
| 15 |
+
|
| 16 |
+
# Tokenizer output (regenerated from HF model)
|
| 17 |
+
/tokenizer_data/
|
| 18 |
+
*.bin
|
| 19 |
+
!tests/**/*.bin
|
| 20 |
+
|
| 21 |
+
# Reference data (regenerated by scripts/gen_*.py; too large to commit)
|
| 22 |
+
/tests/attn_data/
|
| 23 |
+
/tests/moe_data/
|
| 24 |
+
/tests/mm_data/
|
| 25 |
+
/tests/rms_norm_data/
|
| 26 |
+
/tests/poc_data/
|
| 27 |
+
|
| 28 |
+
# Runtime state
|
| 29 |
+
/tmp/
|
| 30 |
+
*.log
|
| 31 |
+
/tp_rank_*.log
|
| 32 |
+
hccl_root_info.bin
|
| 33 |
+
/tmp/hccl_root_info.bin
|
| 34 |
+
|
| 35 |
+
# Editor / IDE
|
| 36 |
+
.vscode/
|
| 37 |
+
.idea/
|
| 38 |
+
*.swp
|
| 39 |
+
*.swo
|
| 40 |
+
.DS_Store
|
| 41 |
+
|
| 42 |
+
# Python
|
| 43 |
+
__pycache__/
|
| 44 |
+
*.pyc
|
| 45 |
+
.ipynb_checkpoints/
|
| 46 |
+
|
| 47 |
+
# Benchmark output
|
| 48 |
+
bench_result*.log
|
|
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
cmake_minimum_required(VERSION 3.16)
|
| 2 |
+
project(qwen3-moe-aclnn CXX)
|
| 3 |
+
|
| 4 |
+
set(CMAKE_CXX_STANDARD 17)
|
| 5 |
+
set(CMAKE_CXX_STANDARD_REQUIRED ON)
|
| 6 |
+
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
|
| 7 |
+
|
| 8 |
+
if(NOT CMAKE_BUILD_TYPE)
|
| 9 |
+
set(CMAKE_BUILD_TYPE Release)
|
| 10 |
+
endif()
|
| 11 |
+
|
| 12 |
+
set(CMAKE_CXX_FLAGS_RELEASE "-O2 -g")
|
| 13 |
+
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall -Wno-unused-function")
|
| 14 |
+
|
| 15 |
+
# CANN paths
|
| 16 |
+
if(NOT DEFINED CANN_INSTALL_DIR)
|
| 17 |
+
if(DEFINED ENV{ASCEND_TOOLKIT_HOME})
|
| 18 |
+
set(CANN_INSTALL_DIR $ENV{ASCEND_TOOLKIT_HOME})
|
| 19 |
+
else()
|
| 20 |
+
set(CANN_INSTALL_DIR /usr/local/Ascend/ascend-toolkit/latest)
|
| 21 |
+
endif()
|
| 22 |
+
endif()
|
| 23 |
+
message(STATUS "CANN_INSTALL_DIR: ${CANN_INSTALL_DIR}")
|
| 24 |
+
|
| 25 |
+
include_directories(
|
| 26 |
+
${CANN_INSTALL_DIR}/include
|
| 27 |
+
${CANN_INSTALL_DIR}/include/aclnn
|
| 28 |
+
${CMAKE_SOURCE_DIR}/include
|
| 29 |
+
${CMAKE_SOURCE_DIR}/external
|
| 30 |
+
)
|
| 31 |
+
|
| 32 |
+
link_directories(${CANN_INSTALL_DIR}/lib64)
|
| 33 |
+
|
| 34 |
+
set(CANN_LIBS ascendcl nnopbase opapi opapi_transformer acl_op_compiler hccl)
|
| 35 |
+
|
| 36 |
+
# HCCL headers live under include/ but we need explicit include dir for <hccl/hccl.h>.
|
| 37 |
+
include_directories(${CANN_INSTALL_DIR}/include/hccl)
|
| 38 |
+
|
| 39 |
+
# ---- Library: qwen3-moe-aclnn core ----
|
| 40 |
+
set(LCA_SOURCES
|
| 41 |
+
src/safetensors_loader.cpp
|
| 42 |
+
src/model_config.cpp
|
| 43 |
+
src/tokenizer.cpp
|
| 44 |
+
src/device_weights.cpp
|
| 45 |
+
src/runner.cpp
|
| 46 |
+
)
|
| 47 |
+
add_library(qwen3-moe-aclnn-core STATIC ${LCA_SOURCES})
|
| 48 |
+
target_link_libraries(qwen3-moe-aclnn-core PUBLIC ${CANN_LIBS})
|
| 49 |
+
|
| 50 |
+
# ---- Binaries ----
|
| 51 |
+
add_executable(hello_acl tests/hello_acl.cpp)
|
| 52 |
+
target_link_libraries(hello_acl qwen3-moe-aclnn-core)
|
| 53 |
+
|
| 54 |
+
add_executable(test_safetensors tests/test_safetensors.cpp)
|
| 55 |
+
target_link_libraries(test_safetensors qwen3-moe-aclnn-core)
|
| 56 |
+
|
| 57 |
+
add_executable(test_model_config tests/test_model_config.cpp)
|
| 58 |
+
target_link_libraries(test_model_config qwen3-moe-aclnn-core)
|
| 59 |
+
|
| 60 |
+
add_executable(test_tokenizer tests/test_tokenizer.cpp)
|
| 61 |
+
target_link_libraries(test_tokenizer qwen3-moe-aclnn-core)
|
| 62 |
+
|
| 63 |
+
add_executable(test_rms_norm tests/test_rms_norm.cpp)
|
| 64 |
+
target_link_libraries(test_rms_norm qwen3-moe-aclnn-core)
|
| 65 |
+
|
| 66 |
+
add_executable(test_weight_load tests/test_weight_load.cpp)
|
| 67 |
+
target_link_libraries(test_weight_load qwen3-moe-aclnn-core)
|
| 68 |
+
|
| 69 |
+
add_executable(test_linear_hf tests/test_linear_hf.cpp)
|
| 70 |
+
target_link_libraries(test_linear_hf qwen3-moe-aclnn-core)
|
| 71 |
+
|
| 72 |
+
add_executable(test_rope tests/test_rope.cpp)
|
| 73 |
+
target_link_libraries(test_rope qwen3-moe-aclnn-core)
|
| 74 |
+
|
| 75 |
+
add_executable(test_rope_manual tests/test_rope_manual.cpp)
|
| 76 |
+
target_link_libraries(test_rope_manual qwen3-moe-aclnn-core)
|
| 77 |
+
|
| 78 |
+
add_executable(test_attention_layer tests/test_attention_layer.cpp)
|
| 79 |
+
target_link_libraries(test_attention_layer qwen3-moe-aclnn-core)
|
| 80 |
+
|
| 81 |
+
add_executable(test_moe_layer tests/test_moe_layer.cpp)
|
| 82 |
+
target_link_libraries(test_moe_layer qwen3-moe-aclnn-core)
|
| 83 |
+
|
| 84 |
+
add_executable(test_attention_decode tests/test_attention_decode.cpp)
|
| 85 |
+
target_link_libraries(test_attention_decode qwen3-moe-aclnn-core)
|
| 86 |
+
|
| 87 |
+
add_executable(test_engine_smoke tests/test_engine_smoke.cpp)
|
| 88 |
+
target_link_libraries(test_engine_smoke qwen3-moe-aclnn-core)
|
| 89 |
+
|
| 90 |
+
add_executable(test_layer_forward tests/test_layer_forward.cpp)
|
| 91 |
+
target_link_libraries(test_layer_forward qwen3-moe-aclnn-core)
|
| 92 |
+
|
| 93 |
+
add_executable(test_runner tests/test_runner.cpp)
|
| 94 |
+
target_link_libraries(test_runner qwen3-moe-aclnn-core)
|
| 95 |
+
|
| 96 |
+
# ---- Main CLI ----
|
| 97 |
+
add_executable(qwen3-moe-aclnn src/main_cli.cpp)
|
| 98 |
+
target_link_libraries(qwen3-moe-aclnn qwen3-moe-aclnn-core)
|
| 99 |
+
|
| 100 |
+
add_executable(test_op_support tests/test_op_support.cpp)
|
| 101 |
+
target_link_libraries(test_op_support qwen3-moe-aclnn-core)
|
| 102 |
+
|
| 103 |
+
add_executable(test_rope_fused tests/test_rope_fused.cpp)
|
| 104 |
+
target_link_libraries(test_rope_fused qwen3-moe-aclnn-core)
|
| 105 |
+
|
| 106 |
+
add_executable(test_batch_decode tests/test_batch_decode.cpp)
|
| 107 |
+
target_link_libraries(test_batch_decode qwen3-moe-aclnn-core)
|
| 108 |
+
|
| 109 |
+
add_executable(test_batch_correctness tests/test_batch_correctness.cpp)
|
| 110 |
+
target_link_libraries(test_batch_correctness qwen3-moe-aclnn-core)
|
|
@@ -0,0 +1,176 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Apache License
|
| 2 |
+
Version 2.0, January 2004
|
| 3 |
+
http://www.apache.org/licenses/
|
| 4 |
+
|
| 5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
| 6 |
+
|
| 7 |
+
1. Definitions.
|
| 8 |
+
|
| 9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
| 10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
| 11 |
+
|
| 12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
| 13 |
+
the copyright owner that is granting the License.
|
| 14 |
+
|
| 15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
| 16 |
+
other entities that control, are controlled by, or are under common
|
| 17 |
+
control with that entity. For the purposes of this definition,
|
| 18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
| 19 |
+
direction or management of such entity, whether by contract or
|
| 20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
| 21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
| 22 |
+
|
| 23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
| 24 |
+
exercising permissions granted by this License.
|
| 25 |
+
|
| 26 |
+
"Source" form shall mean the preferred form for making modifications,
|
| 27 |
+
including but not limited to software source code, documentation
|
| 28 |
+
source, and configuration files.
|
| 29 |
+
|
| 30 |
+
"Object" form shall mean any form resulting from mechanical
|
| 31 |
+
transformation or translation of a Source form, including but
|
| 32 |
+
not limited to compiled object code, generated documentation,
|
| 33 |
+
and conversions to other media types.
|
| 34 |
+
|
| 35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
| 36 |
+
Object form, made available under the License, as indicated by a
|
| 37 |
+
copyright notice that is included in or attached to the work
|
| 38 |
+
(an example is provided in the Appendix below).
|
| 39 |
+
|
| 40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
| 41 |
+
form, that is based on (or derived from) the Work and for which the
|
| 42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
| 43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
| 44 |
+
of this License, Derivative Works shall not include works that remain
|
| 45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
| 46 |
+
the Work and Derivative Works thereof.
|
| 47 |
+
|
| 48 |
+
"Contribution" shall mean any work of authorship, including
|
| 49 |
+
the original version of the Work and any modifications or additions
|
| 50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
| 51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
| 52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
| 53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
| 54 |
+
means any form of electronic, verbal, or written communication sent
|
| 55 |
+
to the Licensor or its representatives, including but not limited to
|
| 56 |
+
communication on electronic mailing lists, source code control systems,
|
| 57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
| 58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
| 59 |
+
excluding communication that is conspicuously marked or otherwise
|
| 60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
| 61 |
+
|
| 62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
| 63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
| 64 |
+
subsequently incorporated within the Work.
|
| 65 |
+
|
| 66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
| 67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
| 70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
| 71 |
+
Work and such Derivative Works in Source or Object form.
|
| 72 |
+
|
| 73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
| 74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 76 |
+
(except as stated in this section) patent license to make, have made,
|
| 77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
| 78 |
+
where such license applies only to those patent claims licensable
|
| 79 |
+
by such Contributor that are necessarily infringed by their
|
| 80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
| 81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
| 82 |
+
institute patent litigation against any entity (including a
|
| 83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
| 84 |
+
or a Contribution incorporated within the Work constitutes direct
|
| 85 |
+
or contributory patent infringement, then any patent licenses
|
| 86 |
+
granted to You under this License for that Work shall terminate
|
| 87 |
+
as of the date such litigation is filed.
|
| 88 |
+
|
| 89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
| 90 |
+
Work or Derivative Works thereof in any medium, with or without
|
| 91 |
+
modifications, and in Source or Object form, provided that You
|
| 92 |
+
meet the following conditions:
|
| 93 |
+
|
| 94 |
+
(a) You must give any other recipients of the Work or
|
| 95 |
+
Derivative Works a copy of this License; and
|
| 96 |
+
|
| 97 |
+
(b) You must cause any modified files to carry prominent notices
|
| 98 |
+
stating that You changed the files; and
|
| 99 |
+
|
| 100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
| 101 |
+
that You distribute, all copyright, patent, trademark, and
|
| 102 |
+
attribution notices from the Source form of the Work,
|
| 103 |
+
excluding those notices that do not pertain to any part of
|
| 104 |
+
the Derivative Works; and
|
| 105 |
+
|
| 106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
| 107 |
+
distribution, then any Derivative Works that You distribute must
|
| 108 |
+
include a readable copy of the attribution notices contained
|
| 109 |
+
within such NOTICE file, excluding those notices that do not
|
| 110 |
+
pertain to any part of the Derivative Works, in at least one
|
| 111 |
+
of the following places: within a NOTICE text file distributed
|
| 112 |
+
as part of the Derivative Works; within the Source form or
|
| 113 |
+
documentation, if provided along with the Derivative Works; or,
|
| 114 |
+
within a display generated by the Derivative Works, if and
|
| 115 |
+
wherever such third-party notices normally appear. The contents
|
| 116 |
+
of the NOTICE file are for informational purposes only and
|
| 117 |
+
do not modify the License. You may add Your own attribution
|
| 118 |
+
notices within Derivative Works that You distribute, alongside
|
| 119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
| 120 |
+
that such additional attribution notices cannot be construed
|
| 121 |
+
as modifying the License.
|
| 122 |
+
|
| 123 |
+
You may add Your own copyright statement to Your modifications and
|
| 124 |
+
may provide additional or different license terms and conditions
|
| 125 |
+
for use, reproduction, or distribution of Your modifications, or
|
| 126 |
+
for any such Derivative Works as a whole, provided Your use,
|
| 127 |
+
reproduction, and distribution of the Work otherwise complies with
|
| 128 |
+
the conditions stated in this License.
|
| 129 |
+
|
| 130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
| 131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
| 132 |
+
by You to the Licensor shall be under the terms and conditions of
|
| 133 |
+
this License, without any additional terms or conditions.
|
| 134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
| 135 |
+
the terms of any separate license agreement you may have executed
|
| 136 |
+
with Licensor regarding such Contributions.
|
| 137 |
+
|
| 138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
| 139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
| 140 |
+
except as required for reasonable and customary use in describing the origin of the Work and
|
| 141 |
+
reproducing the content of the NOTICE file.
|
| 142 |
+
|
| 143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
| 144 |
+
agreed to in writing, Licensor provides the Work (and each
|
| 145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
| 146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
| 147 |
+
implied, including, without limitation, any warranties or conditions
|
| 148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
| 149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
| 150 |
+
appropriateness of using or redistributing the Work and assume any
|
| 151 |
+
risks associated with Your exercise of permissions under this License.
|
| 152 |
+
|
| 153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
| 154 |
+
whether in tort (including negligence), contract, or otherwise,
|
| 155 |
+
unless required by applicable law (such as deliberate and grossly
|
| 156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
| 157 |
+
liable to You for damages, including any direct, indirect, special,
|
| 158 |
+
incidental, or consequential damages of any character arising as a
|
| 159 |
+
result of this License or out of the use or inability to use the
|
| 160 |
+
Work (including but not limited to damages for loss of goodwill,
|
| 161 |
+
work stoppage, computer failure or malfunction, or any and all
|
| 162 |
+
other commercial damages or losses), even if such Contributor
|
| 163 |
+
has been advised of the possibility of such damages.
|
| 164 |
+
|
| 165 |
+
9. Accepting Warranty or Support. While redistributing the Work or
|
| 166 |
+
Derivative Works thereof, You may choose to offer, and charge a
|
| 167 |
+
fee for, acceptance of support, warranty, indemnity, or other
|
| 168 |
+
liability obligations and/or rights consistent with this License.
|
| 169 |
+
However, in accepting such obligations, You may act only on Your
|
| 170 |
+
own behalf and on Your sole responsibility, not on behalf of any
|
| 171 |
+
other Contributor, and only if You agree to indemnify, defend,
|
| 172 |
+
and hold each Contributor harmless for any liability incurred by,
|
| 173 |
+
or claims asserted against, such Contributor by reason of your
|
| 174 |
+
accepting any such warranty or support.
|
| 175 |
+
|
| 176 |
+
END OF TERMS AND CONDITIONS
|
|
@@ -0,0 +1,338 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# qwen3-moe-aclnn
|
| 2 |
+
|
| 3 |
+
Pure C++ inference of **Qwen3-235B-A22B-Instruct** BF16 on **Ascend 910 × 16 NPU**, built directly on the aclnn EAGER API (no graph compilation, no PyTorch, no ggml).
|
| 4 |
+
|
| 5 |
+
---
|
| 6 |
+
|
| 7 |
+
## Performance
|
| 8 |
+
|
| 9 |
+
Measured on Ascend 910 initial-gen × 16 NPU (TP=16) with Qwen3-235B-A22B-Instruct-2507 BF16 weights.
|
| 10 |
+
All numbers are **quality-preserving TG** (output was manually verified); greedy `temperature=0`.
|
| 11 |
+
|
| 12 |
+
| Configuration | TG | Applicable prompts |
|
| 13 |
+
|---|---|---|
|
| 14 |
+
| Untuned baseline | 12 t/s | All |
|
| 15 |
+
| **Default recommended** (no PLD) | **~27 t/s** | **All prompts, stable output** |
|
| 16 |
+
| PLD with degeneration guard | 29-45 t/s | Structured text (essays, long-form answers) |
|
| 17 |
+
| PLD on creative prompts | 25-40 t/s | Stories / varied generation |
|
| 18 |
+
| PLD on factual / code prompts | unstable (21-95 t/s, high variance) | Not recommended |
|
| 19 |
+
|
| 20 |
+
Reference: `cann-recipes-infer` GE graph baseline reports ~54 t/s on the same hardware. **This project does not exceed that baseline** — it trades some peak speed for (a) no graph compilation, (b) no PyTorch dependency, (c) full control over operator scheduling.
|
| 21 |
+
|
| 22 |
+
### Key optimizations that contributed (in order of magnitude)
|
| 23 |
+
|
| 24 |
+
| Rank | Optimization | Gain | Where |
|
| 25 |
+
|---|---|---|---|
|
| 26 |
+
| 🥇 | HCCL env tuning (`AIV` + `FFTS` + `TASK_QUEUE=2`) | +89% (12→23 t/s) | `scripts/tp_launch.sh` |
|
| 27 |
+
| 🥈 | Fused RoPE via `aclnnApplyRotaryPosEmbV2` | +17% (23→27 t/s) | `include/rope.h` |
|
| 28 |
+
| 🥉 | Prompt Lookup Decoding (PLD) w/ degeneration guard | +10-60% on applicable prompts | `src/main_cli.cpp` |
|
| 29 |
+
| ○ | Device-side topk-w normalize, MoE argsort, cos/sin cache | ~+15% cumulative | `include/engine.h` |
|
| 30 |
+
| ○ | WorkspacePool (thread-local + retain-old) | reduces alloc overhead | `include/workspace_pool.h` |
|
| 31 |
+
|
| 32 |
+
---
|
| 33 |
+
|
| 34 |
+
## Architecture
|
| 35 |
+
|
| 36 |
+
**Model**: Qwen3-235B-A22B, 94 layers, 128 experts (top-k=8), GQA (64 Q heads, 4 KV heads), BF16.
|
| 37 |
+
|
| 38 |
+
**Parallelism**: TP=16 via HCCL ring AllReduce. KV heads sharded 1-per-rank (since 4 KV heads < 16 ranks, Q heads 0-3 on each rank share KV head 0).
|
| 39 |
+
|
| 40 |
+
**Execution**: aclnn EAGER mode — every op goes through `aclnn*` single-op API with workspace pool; no graph capture, no GE IR. Async stream execution with `TASK_QUEUE_ENABLE=2` for kernel submission overlap.
|
| 41 |
+
|
| 42 |
+
**Tokenizer**: Uses HuggingFace `transformers` via a Python subprocess for encoding; vocab decode is pure C++ from an exported `vocab.bin`.
|
| 43 |
+
|
| 44 |
+
### Per-layer forward flow
|
| 45 |
+
|
| 46 |
+
```
|
| 47 |
+
x_in [S, D=4096]
|
| 48 |
+
↓
|
| 49 |
+
┌── Attention branch (TP: Q_DIM=512=4h×128, KV_DIM=128=1h×128) ──┐
|
| 50 |
+
│ RmsNorm(input_layernorm)
|
| 51 |
+
│ linear_hf q_proj / k_proj / v_proj → q, k, v
|
| 52 |
+
│ Per-head RmsNorm q_norm, k_norm
|
| 53 |
+
│ Fused RoPE: aclnnApplyRotaryPosEmbV2 (layout=1, "half")
|
| 54 |
+
│ Append K, V to per-layer KV cache
|
| 55 |
+
│ Mask selection:
|
| 56 |
+
│ prefill: 2048×2048 causal + sparse_mode=3
|
| 57 |
+
│ decode S=1: nullptr + sparse_mode=0
|
| 58 |
+
│ batch decode: [1,1,S,past+S] custom bool mask + sparse_mode=0
|
| 59 |
+
│ FIAS (aclnnFusedInferAttentionScore)
|
| 60 |
+
│ o_proj linear_hf → partial per-rank
|
| 61 |
+
│ HCCL AllReduce (ring + AIV + FFTS) → full
|
| 62 |
+
└─────────┘
|
| 63 |
+
↓ residual add
|
| 64 |
+
┌── MoE branch ──┐
|
| 65 |
+
│ RmsNorm(post_attention_layernorm)
|
| 66 |
+
│ router linear_hf → logits [S, 128]
|
| 67 |
+
│ moe_gating_topk_softmax → topk_w[S,8], topk_idx[S,8]
|
| 68 |
+
│ Device-side normalize (reduce_sum + adds + cast + div)
|
| 69 |
+
│ moe_init_routing_v3 → expanded_x, expanded_ri, tokens_per_expert
|
| 70 |
+
│ grouped_matmul_v4 gate/up/down (SwiGLU activation)
|
| 71 |
+
│ Device-side argsort × 2 → fwd permutation (avoids host sync)
|
| 72 |
+
│ IndexSelect → packed
|
| 73 |
+
│ Broadcast-mul by topk_w + ReduceSum axis=1
|
| 74 |
+
│ HCCL AllReduce → full
|
| 75 |
+
└─────────┘
|
| 76 |
+
↓ residual add
|
| 77 |
+
x_out
|
| 78 |
+
```
|
| 79 |
+
|
| 80 |
+
---
|
| 81 |
+
|
| 82 |
+
## Model weights
|
| 83 |
+
|
| 84 |
+
This project targets **Qwen3-235B-A22B-Instruct-2507** (BF16). About **470 GB** of safetensors shards.
|
| 85 |
+
|
| 86 |
+
**Download sources**:
|
| 87 |
+
- HuggingFace: https://huggingface.co/Qwen/Qwen3-235B-A22B-Instruct-2507
|
| 88 |
+
- ModelScope: https://www.modelscope.cn/models/Qwen/Qwen3-235B-A22B-Instruct-2507
|
| 89 |
+
|
| 90 |
+
Download via `huggingface-cli` or `modelscope` CLI:
|
| 91 |
+
```bash
|
| 92 |
+
# HuggingFace
|
| 93 |
+
huggingface-cli download Qwen/Qwen3-235B-A22B-Instruct-2507 --local-dir /path/to/Qwen3-235B-A22B-Instruct-2507-BF16
|
| 94 |
+
|
| 95 |
+
# ModelScope
|
| 96 |
+
modelscope download --model Qwen/Qwen3-235B-A22B-Instruct-2507 --local_dir /path/to/Qwen3-235B-A22B-Instruct-2507-BF16
|
| 97 |
+
```
|
| 98 |
+
|
| 99 |
+
**Weights format**: the binary reads HuggingFace `.safetensors` shards (multi-shard mmap), `config.json`, and `tokenizer.json` directly from the model directory. No conversion step is needed — point `--model-dir` at the downloaded directory.
|
| 100 |
+
|
| 101 |
+
**Expected directory contents**:
|
| 102 |
+
```
|
| 103 |
+
Qwen3-235B-A22B-Instruct-2507-BF16/
|
| 104 |
+
├── config.json
|
| 105 |
+
├── tokenizer.json
|
| 106 |
+
├── tokenizer_config.json
|
| 107 |
+
├── model-00001-of-000XX.safetensors
|
| 108 |
+
├── ...
|
| 109 |
+
└── model.safetensors.index.json
|
| 110 |
+
```
|
| 111 |
+
|
| 112 |
+
---
|
| 113 |
+
|
| 114 |
+
## Build
|
| 115 |
+
|
| 116 |
+
```bash
|
| 117 |
+
source /usr/local/Ascend/ascend-toolkit/set_env.sh
|
| 118 |
+
cmake -B build
|
| 119 |
+
cmake --build build -j8 --target qwen3-moe-aclnn
|
| 120 |
+
```
|
| 121 |
+
|
| 122 |
+
**Requires**:
|
| 123 |
+
- CANN 8.5.1 or compatible
|
| 124 |
+
- Python 3 + `transformers` + `torch_npu` (for tokenizer subprocess and reference-data generation only)
|
| 125 |
+
- C++17 compiler
|
| 126 |
+
- Ascend 910 × 16 NPU
|
| 127 |
+
- nlohmann/json (bundled as `external/json.hpp`)
|
| 128 |
+
|
| 129 |
+
**Python environment setup** — the tokenizer calls a Python subprocess. Override the activation command via `QWEN3_PYENV_INIT` if your conda / venv layout differs from the default:
|
| 130 |
+
```bash
|
| 131 |
+
export QWEN3_PYENV_INIT="source /opt/my_conda/etc/profile.d/conda.sh && conda activate my_env && "
|
| 132 |
+
```
|
| 133 |
+
If unset, the default tries `${HOME}/miniconda3` with env `qwen3` and auto-sources the Ascend toolkit.
|
| 134 |
+
|
| 135 |
+
---
|
| 136 |
+
|
| 137 |
+
## Quick-start inference
|
| 138 |
+
|
| 139 |
+
```bash
|
| 140 |
+
# 1. Export tokenizer vocab to binary (one-time setup)
|
| 141 |
+
python3 scripts/export_vocab.py /path/to/Qwen3-235B-A22B-Instruct-2507-BF16
|
| 142 |
+
|
| 143 |
+
# 2. Run inference (TP=16)
|
| 144 |
+
./scripts/tp_launch.sh 16 ./build/qwen3-moe-aclnn \
|
| 145 |
+
--model-dir /path/to/Qwen3-235B-A22B-Instruct-2507-BF16 \
|
| 146 |
+
--prompt "The capital of France is" \
|
| 147 |
+
--n-predict 100 \
|
| 148 |
+
--temperature 0 \
|
| 149 |
+
--vocab tokenizer_data/vocab.bin
|
| 150 |
+
```
|
| 151 |
+
|
| 152 |
+
Expected: ~27 t/s, coherent output.
|
| 153 |
+
|
| 154 |
+
### Recommended flags by use case
|
| 155 |
+
|
| 156 |
+
**Universal default (stable, any prompt)** — no PLD:
|
| 157 |
+
```bash
|
| 158 |
+
./scripts/tp_launch.sh 16 ./build/qwen3-moe-aclnn --model-dir ... --temperature 0 --no-stream
|
| 159 |
+
```
|
| 160 |
+
|
| 161 |
+
**Structured / long-form (essays, explanations)** — PLD with guard gives +60-90%:
|
| 162 |
+
```bash
|
| 163 |
+
./scripts/tp_launch.sh 16 ./build/qwen3-moe-aclnn --model-dir ... --pld --temperature 0 --no-stream
|
| 164 |
+
```
|
| 165 |
+
|
| 166 |
+
**Interactive REPL (multi-turn chat)**:
|
| 167 |
+
```bash
|
| 168 |
+
./scripts/tp_launch.sh 16 ./build/qwen3-moe-aclnn --model-dir ... \
|
| 169 |
+
--interactive --chat --temperature 0.7 --top-p 0.8
|
| 170 |
+
```
|
| 171 |
+
|
| 172 |
+
---
|
| 173 |
+
|
| 174 |
+
## PLD degeneration guard
|
| 175 |
+
|
| 176 |
+
Prompt Lookup Decoding speeds up generation by having the model verify a batch of "draft" tokens in a single forward pass. The drafts are copied from the generation history via n-gram match.
|
| 177 |
+
|
| 178 |
+
**Known failure mode**: on prompts the model tends to repeat on (factual Q&A, code generation), the n-gram match feeds the model's own repetition back as drafts, creating a positive feedback loop that accelerates degenerate output. Early versions of this project reported misleading peak TG numbers driven by this loop.
|
| 179 |
+
|
| 180 |
+
**This project's guard** blocks suspect drafts with two heuristics:
|
| 181 |
+
|
| 182 |
+
1. **low-distinct**: draft's distinct-token count < threshold → reject
|
| 183 |
+
2. **tail-echo**: all of last N hist tokens equal draft[0] → reject
|
| 184 |
+
|
| 185 |
+
Rejected drafts fall back to single-token decode. A `[warn]` line is emitted once if the generated tail shows 8 consecutive identical tokens.
|
| 186 |
+
|
| 187 |
+
Flags:
|
| 188 |
+
```
|
| 189 |
+
--pld enable PLD (opt-in)
|
| 190 |
+
--pld-k N draft window size (default: 10)
|
| 191 |
+
--pld-ngram N n-gram match size (default: 1, with multi-level fallback)
|
| 192 |
+
--pld-min-hist N skip PLD until history >= N tokens (default: 20)
|
| 193 |
+
--pld-no-guard disable the degeneration guard (dangerous: can produce dead loops)
|
| 194 |
+
--pld-guard-distinct N minimum distinct tokens in draft (default: 3)
|
| 195 |
+
--pld-guard-tail N tail-echo window (default: 6)
|
| 196 |
+
--pld-loop-warn N emit warning on N consecutive identical tokens (default: 8)
|
| 197 |
+
```
|
| 198 |
+
|
| 199 |
+
**Honest benchmarking**: use `scripts/bench_pld_safe.sh`, which classifies each run's output as OK / LOOP_N / LOW_DIVERSITY and separates TG statistics for OK-only vs degraded runs.
|
| 200 |
+
|
| 201 |
+
---
|
| 202 |
+
|
| 203 |
+
## Correctness verification
|
| 204 |
+
|
| 205 |
+
15+ unit / integration tests checked against Python (HuggingFace Transformers) reference:
|
| 206 |
+
|
| 207 |
+
```bash
|
| 208 |
+
./build/test_attention_layer # rel=4.9e-4 vs Python prefill
|
| 209 |
+
./build/test_attention_decode # rel=0 (bit-exact)
|
| 210 |
+
./build/test_moe_layer # rel=3.6e-3
|
| 211 |
+
./build/test_layer_forward # full single layer
|
| 212 |
+
./build/test_runner # multi-layer runner
|
| 213 |
+
./build/test_rope_fused # aclnnApplyRotaryPosEmbV2 vs manual HF rotate_half
|
| 214 |
+
./build/test_batch_decode # S=1..8 timing
|
| 215 |
+
./build/test_batch_correctness # argmax consistency
|
| 216 |
+
./build/test_op_support # 910-specific op availability
|
| 217 |
+
# Integration smoke:
|
| 218 |
+
./tests/test_chat_flow.sh # 7/7 PASS
|
| 219 |
+
```
|
| 220 |
+
|
| 221 |
+
Tests expect reference data under `tests/<name>_data/` generated by `scripts/gen_*_reference.py`. See each script's docstring.
|
| 222 |
+
|
| 223 |
+
---
|
| 224 |
+
|
| 225 |
+
## Environment tuning (auto-applied by `tp_launch.sh`)
|
| 226 |
+
|
| 227 |
+
```bash
|
| 228 |
+
HCCL_WHITELIST_DISABLE=1
|
| 229 |
+
HCCL_ALGO=level0:ring # ring, not fullmesh (fullmesh causes garbled output)
|
| 230 |
+
HCCL_BUFFSIZE=200 # sweet spot; 100 and 400 both slower
|
| 231 |
+
HCCL_OP_EXPANSION_MODE=AIV # key: AI Vector cores participate in reduce scheduling
|
| 232 |
+
HCCL_OP_BASE_FFTS_MODE_ENABLE=1 # key: Fast Frequently-used Transfer Scheduling
|
| 233 |
+
TASK_QUEUE_ENABLE=2 # key: aggressive async task submission
|
| 234 |
+
```
|
| 235 |
+
|
| 236 |
+
Removing any of the three "key" env vars drops TG by 20-40%.
|
| 237 |
+
|
| 238 |
+
---
|
| 239 |
+
|
| 240 |
+
## Directory layout
|
| 241 |
+
|
| 242 |
+
```
|
| 243 |
+
include/
|
| 244 |
+
├── acl_common.h RAII wrappers, DeviceBuffer, make_contig_tensor
|
| 245 |
+
├── aclnn_ops.h single-op wrappers + WorkspacePool integration
|
| 246 |
+
├── acl_runtime.h AclRuntime (device + stream management)
|
| 247 |
+
├── device_weights.h safetensors → device loading + TP sharding
|
| 248 |
+
├── engine.h attention_forward + moe_forward + RopeCache
|
| 249 |
+
├── hccl_comm.h HCCL init + allreduce + broadcast
|
| 250 |
+
├── model_config.h Qwen3 hyperparameters + compute_derived
|
| 251 |
+
├── rope.h apply_rope_fused (aclnnApplyRotaryPosEmbV2 wrapper)
|
| 252 |
+
├── runner.h Runner class (prefill/decode/decode_batch/rewind/profile)
|
| 253 |
+
├── safetensors_loader.h multi-shard safetensors mmap parser
|
| 254 |
+
├── tokenizer.h vocab decode + Python subprocess encode
|
| 255 |
+
└── workspace_pool.h thread-local aclnn workspace pool (retain-old)
|
| 256 |
+
|
| 257 |
+
src/
|
| 258 |
+
├── device_weights.cpp load_attention (GQA fix), load_moe (permute sync fix)
|
| 259 |
+
├── main_cli.cpp CLI entry + PLD main loop + degeneration guard + multi-turn
|
| 260 |
+
├── model_config.cpp compute_derived (GQA KV sharding)
|
| 261 |
+
├── runner.cpp Runner (build_batch_decode_mask_ etc.)
|
| 262 |
+
├── safetensors_loader.cpp
|
| 263 |
+
└── tokenizer.cpp
|
| 264 |
+
|
| 265 |
+
scripts/
|
| 266 |
+
├── tp_launch.sh production launcher (auto-applies HCCL env)
|
| 267 |
+
├── bench_tg.sh stable N-run TG measurement
|
| 268 |
+
├── bench_pld_safe.sh PLD benchmark with output-correctness classifier
|
| 269 |
+
├── bench_hccl[_adv].sh HCCL parameter sweep
|
| 270 |
+
├── bench_pld[_k].sh PLD K × ngram sweep (legacy, prefer bench_pld_safe.sh)
|
| 271 |
+
├── export_vocab.py vocab.bin exporter from HF tokenizer
|
| 272 |
+
└── gen_*_reference.py per-op Python reference data generators
|
| 273 |
+
|
| 274 |
+
tests/
|
| 275 |
+
├── test_attention_* attention correctness (prefill / decode)
|
| 276 |
+
├── test_moe_layer MoE correctness
|
| 277 |
+
├── test_layer_forward full single layer
|
| 278 |
+
├── test_runner multi-layer Runner
|
| 279 |
+
├── test_rope_fused fused RoPE vs manual HF
|
| 280 |
+
├── test_batch_* batch decode timing + correctness
|
| 281 |
+
├── test_op_support 910-specific op availability probe
|
| 282 |
+
└── test_chat_flow.sh end-to-end integration smoke
|
| 283 |
+
```
|
| 284 |
+
|
| 285 |
+
---
|
| 286 |
+
|
| 287 |
+
## CLI reference
|
| 288 |
+
|
| 289 |
+
```
|
| 290 |
+
--model-dir <path> (required) HF safetensors directory
|
| 291 |
+
--prompt "<text>" prompt text
|
| 292 |
+
--prompt-file FILE read prompt from file (avoids shell-escape issues)
|
| 293 |
+
--n-predict N maximum tokens to generate
|
| 294 |
+
--tp-size N tensor parallelism (or set TP_SIZE env)
|
| 295 |
+
--max-seq N KV cache + context cap (default: 512)
|
| 296 |
+
--temperature F 0 = greedy; typical 0.7
|
| 297 |
+
--top-k N 0 = disabled
|
| 298 |
+
--top-p F 1.0 = disabled
|
| 299 |
+
--seed N 0 = time-based
|
| 300 |
+
--chat apply Qwen3 chat template
|
| 301 |
+
--system "<text>" system role text (with --chat)
|
| 302 |
+
--interactive, -i REPL mode (multi-turn memory with --chat)
|
| 303 |
+
--reset force stateless REPL (reset KV between turns)
|
| 304 |
+
--no-stream batch-print final text instead of per-token streaming
|
| 305 |
+
--vocab <path> vocab.bin path (default: tokenizer_data/vocab.bin)
|
| 306 |
+
--pld* see "PLD degeneration guard" section
|
| 307 |
+
```
|
| 308 |
+
|
| 309 |
+
---
|
| 310 |
+
|
| 311 |
+
## Known limitations
|
| 312 |
+
|
| 313 |
+
- **Not yet reaching cann-recipes GE graph 54 t/s baseline** (currently ~27 t/s stable / up to ~45 t/s PLD).
|
| 314 |
+
Closing the gap requires one of: (a) real graph compilation, (b) fused collectives (`MatmulAllReduce`, `GroupedMatmulAllReduce`) which are absent on 910 initial-gen, (c) migration to 910B/A2/A3.
|
| 315 |
+
- **Only `tp_size` ∈ {1, 2, 4, 8, 16}** supported. Values that don't evenly divide 64 Q heads will error.
|
| 316 |
+
- **PLD on factual/code prompts is unreliable** — either produces baseline TG (guard rejects most drafts) or enters partial degeneration the classifier may not catch at low-severity. Use `bench_pld_safe.sh` to evaluate honestly.
|
| 317 |
+
- **Tokenizer requires Python subprocess** — adds ~1s startup for first encode. Override via `QWEN3_PYENV_INIT` env if default conda path doesn't match.
|
| 318 |
+
- **NPU performance has high run-to-run variance** (up to 4× in some configurations) due to BF16 + MoE intrinsic non-determinism and shared hardware resources. Report medians over ≥5 runs.
|
| 319 |
+
|
| 320 |
+
---
|
| 321 |
+
|
| 322 |
+
## Future directions (prioritized)
|
| 323 |
+
|
| 324 |
+
1. **Draft Model Speculative Decoding** with Qwen3-0.6B — more stable accept rate than n-gram PLD, expected +60-100% TG across prompt types (1-2 week implementation).
|
| 325 |
+
2. **HCCL AllReduce / compute overlap** — ~+10-15% in theory, limited by EAGER path serial dependencies.
|
| 326 |
+
3. **KV cache INT8 quantization** — reduces memory-bandwidth pressure, ~+15-25% on long contexts (pending 910-initial-gen op support verification).
|
| 327 |
+
4. **W8 weight quantization** — ~+10-20% if aclnn quantization kernels exist on 910 initial-gen.
|
| 328 |
+
|
| 329 |
+
Not recommended:
|
| 330 |
+
- `aclmdlRI` stream-capture-style graph recording (POC proved 1.13× ceiling, not worth the engineering cost).
|
| 331 |
+
- Custom AscendC fused ops (high maintenance cost unless dedicated kernel engineer).
|
| 332 |
+
- torchair / torch.compile migration (breaks pure-C++ design).
|
| 333 |
+
|
| 334 |
+
---
|
| 335 |
+
|
| 336 |
+
## License
|
| 337 |
+
|
| 338 |
+
Apache License 2.0 — see `LICENSE`.
|
|
The diff for this file is too large to render.
See raw diff
|
|
|
|
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
#include <acl/acl.h>
|
| 3 |
+
#include <aclnn/acl_meta.h>
|
| 4 |
+
#include <cstdio>
|
| 5 |
+
#include <cstdlib>
|
| 6 |
+
#include <memory>
|
| 7 |
+
#include <string>
|
| 8 |
+
#include <vector>
|
| 9 |
+
|
| 10 |
+
// Abort-on-failure check for acl / aclrt runtime calls.
// On error: prints the numeric error code, source location and the failing
// expression, then aborts (the runtime state is unrecoverable for inference).
// FIX: macro locals renamed — identifiers containing a double underscore
// (previously `__e`) are reserved to the implementation in C++ ([lex.name]).
#define ACL_CHECK(x) do { \
    aclError acl_check_err_ = (x); \
    if (acl_check_err_ != ACL_ERROR_NONE) { \
        fprintf(stderr, "ACL error %d at %s:%d : %s\n", acl_check_err_, __FILE__, __LINE__, #x); \
        std::abort(); \
    } \
} while(0)
|
| 17 |
+
|
| 18 |
+
// Abort-on-failure check for aclnn single-op calls. Also fetches the most
// recent detailed error message via aclGetRecentErrMsg for diagnosis.
// FIX: macro locals renamed — identifiers containing a double underscore
// (previously `__e`, `__msg`) are reserved to the implementation in C++.
#define ACLNN_CHECK(x) do { \
    aclnnStatus aclnn_check_err_ = (x); \
    if (aclnn_check_err_ != 0) { \
        const char* aclnn_check_msg_ = aclGetRecentErrMsg(); \
        fprintf(stderr, "aclnn error %d at %s:%d : %s\n msg: %s\n", (int)aclnn_check_err_, __FILE__, __LINE__, #x, aclnn_check_msg_ ? aclnn_check_msg_ : "(null)"); \
        std::abort(); \
    } \
} while(0)
|
| 26 |
+
|
| 27 |
+
// RAII wrappers for ACL host-side handle types: each unique_ptr alias calls
// the matching acl destroy function when the handle goes out of scope.

// aclTensor handle: destroyed with aclDestroyTensor on dtor.
struct AclTensorDel { void operator()(aclTensor* t) const { if (t) aclDestroyTensor(t); } };
using AclTensorPtr = std::unique_ptr<aclTensor, AclTensorDel>;

// aclTensorList handle: NOTE that destroying a list also destroys the tensors
// it contains (the list takes ownership) — see the grouped_matmul_v4 note.
struct AclTensorListDel { void operator()(aclTensorList* t) const { if (t) aclDestroyTensorList(t); } };
using AclTensorListPtr = std::unique_ptr<aclTensorList, AclTensorListDel>;

// aclIntArray handle: destroyed with aclDestroyIntArray on dtor.
struct AclIntArrayDel { void operator()(aclIntArray* a) const { if (a) aclDestroyIntArray(a); } };
using AclIntArrayPtr = std::unique_ptr<aclIntArray, AclIntArrayDel>;
|
| 36 |
+
|
| 37 |
+
// Create ACL tensor with explicit row-major shape (outermost leftmost) and element strides.
|
| 38 |
+
// NOTE: stride is in ELEMENTS, not bytes.
|
| 39 |
+
inline AclTensorPtr make_acl_tensor(void* data, aclDataType dt,
|
| 40 |
+
const std::vector<int64_t>& shape,
|
| 41 |
+
const std::vector<int64_t>& stride_elems,
|
| 42 |
+
aclFormat fmt = ACL_FORMAT_ND) {
|
| 43 |
+
int64_t n = (int64_t)shape.size();
|
| 44 |
+
int64_t storage_len = 1;
|
| 45 |
+
for (int i = 0; i < n; i++) storage_len += (shape[i] - 1) * stride_elems[i];
|
| 46 |
+
aclTensor* t = aclCreateTensor(
|
| 47 |
+
shape.data(), (uint64_t)n, dt,
|
| 48 |
+
stride_elems.data(), 0, fmt,
|
| 49 |
+
&storage_len, 1, data);
|
| 50 |
+
return AclTensorPtr(t);
|
| 51 |
+
}
|
| 52 |
+
|
| 53 |
+
// Default contiguous strides for row-major tensor: stride[i] = product of shape[i+1..n-1]
|
| 54 |
+
inline std::vector<int64_t> contiguous_strides(const std::vector<int64_t>& shape) {
|
| 55 |
+
int n = (int)shape.size();
|
| 56 |
+
std::vector<int64_t> s(n);
|
| 57 |
+
int64_t acc = 1;
|
| 58 |
+
for (int i = n - 1; i >= 0; --i) {
|
| 59 |
+
s[i] = acc;
|
| 60 |
+
acc *= shape[i];
|
| 61 |
+
}
|
| 62 |
+
return s;
|
| 63 |
+
}
|
| 64 |
+
|
| 65 |
+
// Convenience wrapper: build a tensor view over `data` assuming densely-packed
// row-major layout (strides derived from the shape).
inline AclTensorPtr make_contig_tensor(void* data, aclDataType dt,
                                       const std::vector<int64_t>& shape,
                                       aclFormat fmt = ACL_FORMAT_ND) {
    const std::vector<int64_t> strides = contiguous_strides(shape);
    return make_acl_tensor(data, dt, shape, strides, fmt);
}
|
| 70 |
+
|
| 71 |
+
// Byte width of one element of the given ACL dtype; 0 for dtypes this runtime
// does not handle (callers should treat 0 as "unsupported").
inline size_t dtype_size(aclDataType dt) {
    switch (dt) {
        case ACL_FLOAT:
        case ACL_INT32:
            return 4;
        case ACL_FLOAT16:
        case ACL_BF16:
            return 2;
        case ACL_INT8:
            return 1;
        case ACL_INT64:
            return 8;
        default:
            return 0;
    }
}
|
| 82 |
+
|
| 83 |
+
// Device buffer RAII: allocates via aclrtMalloc, frees in dtor
|
| 84 |
+
struct DeviceBuffer {
|
| 85 |
+
void* ptr = nullptr;
|
| 86 |
+
size_t size = 0;
|
| 87 |
+
|
| 88 |
+
DeviceBuffer() = default;
|
| 89 |
+
explicit DeviceBuffer(size_t bytes) { alloc(bytes); }
|
| 90 |
+
~DeviceBuffer() { if (ptr) aclrtFree(ptr); }
|
| 91 |
+
DeviceBuffer(const DeviceBuffer&) = delete;
|
| 92 |
+
DeviceBuffer& operator=(const DeviceBuffer&) = delete;
|
| 93 |
+
DeviceBuffer(DeviceBuffer&& o) noexcept : ptr(o.ptr), size(o.size) { o.ptr = nullptr; o.size = 0; }
|
| 94 |
+
DeviceBuffer& operator=(DeviceBuffer&& o) noexcept {
|
| 95 |
+
if (this != &o) { if (ptr) aclrtFree(ptr); ptr = o.ptr; size = o.size; o.ptr = nullptr; o.size = 0; }
|
| 96 |
+
return *this;
|
| 97 |
+
}
|
| 98 |
+
|
| 99 |
+
void alloc(size_t bytes) {
|
| 100 |
+
if (ptr) aclrtFree(ptr);
|
| 101 |
+
ACL_CHECK(aclrtMalloc(&ptr, bytes, ACL_MEM_MALLOC_HUGE_FIRST));
|
| 102 |
+
size = bytes;
|
| 103 |
+
}
|
| 104 |
+
void* get() { return ptr; }
|
| 105 |
+
const void* get() const { return ptr; }
|
| 106 |
+
};
|
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// acl_runtime.h — per-rank ACL runtime init/teardown.
|
| 2 |
+
#pragma once
|
| 3 |
+
#include "acl_common.h"
|
| 4 |
+
#include <cstdio>
|
| 5 |
+
|
| 6 |
+
// Per-rank ACL runtime: owns device binding, one context and one stream.
// Lifecycle: init(device_id) → use stream()/sync() → shutdown() (also run by dtor).
class AclRuntime {
public:
    AclRuntime() = default;
    ~AclRuntime() { shutdown(); }

    // Initialize ACL, bind to `device_id`, create a context and a stream.
    // Idempotent: a second call is a no-op that returns true.
    // NOTE(review): failures abort inside ACL_CHECK, so the bool return is
    // effectively always true — kept for API symmetry.
    // NOTE(review): aclInit/aclFinalize are process-level calls; presumably
    // each TP rank runs in its own process so one AclRuntime per process holds
    // — confirm before instantiating two runtimes in one process.
    bool init(int device_id) {
        if (initialized_) return true;
        device_id_ = device_id;
        ACL_CHECK(aclInit(nullptr));
        ACL_CHECK(aclrtSetDevice(device_id));
        ACL_CHECK(aclrtCreateContext(&ctx_, device_id));
        ACL_CHECK(aclrtCreateStream(&stream_));
        initialized_ = true;
        return true;
    }

    // Tear down in reverse creation order: stream → context → device reset →
    // aclFinalize. Teardown errors are deliberately ignored (best-effort
    // cleanup on the shutdown path); do not reorder these calls.
    void shutdown() {
        if (!initialized_) return;
        if (stream_) { aclrtDestroyStream(stream_); stream_ = nullptr; }
        if (ctx_) { aclrtDestroyContext(ctx_); ctx_ = nullptr; }
        aclrtResetDevice(device_id_);
        aclFinalize();
        initialized_ = false;
    }

    // Block the host until all work queued on this runtime's stream completes.
    void sync() { if (stream_) ACL_CHECK(aclrtSynchronizeStream(stream_)); }

    // Borrowed handle — owned and destroyed by this object; do not destroy.
    aclrtStream stream() const { return stream_; }
    int device_id() const { return device_id_; }

private:
    bool initialized_ = false;
    int device_id_ = 0;
    aclrtContext ctx_ = nullptr;
    aclrtStream stream_ = nullptr;
};
|
|
@@ -0,0 +1,345 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// aclnn_ops.h — thin wrappers around common aclnn operators used in forward pass.
|
| 2 |
+
// Each wrapper does GetWorkspaceSize + op call on the provided stream.
|
| 3 |
+
//
|
| 4 |
+
// All tensors are passed as raw aclTensor* (caller owns them).
|
| 5 |
+
// Workspace allocation uses DeviceBuffer (RAII).
|
| 6 |
+
#pragma once
|
| 7 |
+
#include "acl_common.h"
|
| 8 |
+
#include "workspace_pool.h"
|
| 9 |
+
|
| 10 |
+
// Thread-local shared workspace pool backing every aclnn wrapper below.
// Ops in this runtime are issued serially on a single stream per thread, so one
// reusable buffer per thread can safely back consecutive op calls ("retain-old"
// per workspace_pool.h). There is currently no opt-out flag; if one is ever
// needed, it would have to be wired here.
inline WorkspacePool& _lca_pool() {
    thread_local WorkspacePool pool;
    return pool;
}
|
| 17 |
+
|
| 18 |
+
#include <aclnnop/aclnn_add.h>
|
| 19 |
+
#include <aclnnop/aclnn_addcmul.h>
|
| 20 |
+
#include <aclnnop/aclnn_grouped_matmul_v4.h>
|
| 21 |
+
#include <aclnnop/aclnn_moe_finalize_routing.h>
|
| 22 |
+
#include <aclnnop/aclnn_moe_finalize_routing_v2.h>
|
| 23 |
+
#include <aclnnop/aclnn_moe_gating_top_k_softmax.h>
|
| 24 |
+
#include <aclnnop/aclnn_moe_init_routing_v3.h>
|
| 25 |
+
#include <aclnnop/aclnn_cast.h>
|
| 26 |
+
#include <aclnnop/aclnn_copy.h>
|
| 27 |
+
#include <aclnnop/aclnn_div.h>
|
| 28 |
+
#include <aclnnop/aclnn_fused_infer_attention_score.h>
|
| 29 |
+
#include <aclnnop/aclnn_index_select.h>
|
| 30 |
+
#include <aclnnop/aclnn_matmul.h>
|
| 31 |
+
#include <aclnnop/aclnn_mul.h>
|
| 32 |
+
#include <aclnnop/aclnn_neg.h>
|
| 33 |
+
#include <aclnnop/aclnn_reduce_sum.h>
|
| 34 |
+
#include <aclnnop/aclnn_silu.h>
|
| 35 |
+
|
| 36 |
+
// ---- RmsNorm ----
|
| 37 |
+
// Signature (based on ggml-cann usage): aclnnRmsNorm(x, gamma, eps, y, rstd)
|
| 38 |
+
// where rstd (rsqrt of mean-square) is an extra output we usually discard.
|
| 39 |
+
|
| 40 |
+
// Forward declare header; include happens in impl file to keep this header light.
|
| 41 |
+
extern "C" {
|
| 42 |
+
#include <aclnnop/aclnn_rms_norm.h>
|
| 43 |
+
}
|
| 44 |
+
|
| 45 |
+
// RMS normalization: y = x / rms(x) * gamma via aclnnRmsNorm.
// Signature follows ggml-cann usage: (x, gamma, eps, y, rstd), where rstd
// (reciprocal root-mean-square) is a mandatory extra output usually discarded.
inline void rms_norm(aclrtStream stream,
                     aclTensor* x,      // [N, D] BF16/FP16
                     aclTensor* gamma,  // [D] same dtype as x
                     double eps,
                     aclTensor* y,      // [N, D]
                     aclTensor* rstd    // [N] fp32 (required output)
) {
    uint64_t ws_bytes = 0;
    aclOpExecutor* executor = nullptr;
    ACLNN_CHECK(aclnnRmsNormGetWorkspaceSize(x, gamma, eps, y, rstd, &ws_bytes, &executor));
    void* workspace = nullptr;
    if (ws_bytes > 0) workspace = _lca_pool().alloc(ws_bytes);
    ACLNN_CHECK(aclnnRmsNorm(workspace, ws_bytes, executor, stream));
}
|
| 58 |
+
|
| 59 |
+
// ---- Silu ----
|
| 60 |
+
inline void silu(aclrtStream stream, aclTensor* x, aclTensor* y) {
|
| 61 |
+
uint64_t ws = 0;
|
| 62 |
+
aclOpExecutor* exec = nullptr;
|
| 63 |
+
ACLNN_CHECK(aclnnSiluGetWorkspaceSize(x, y, &ws, &exec));
|
| 64 |
+
void* wp = (ws > 0) ? _lca_pool().alloc(ws) : nullptr;
|
| 65 |
+
ACLNN_CHECK(aclnnSilu(wp, ws, exec, stream));
|
| 66 |
+
}
|
| 67 |
+
|
| 68 |
+
// ---- Mul (element-wise) ----
|
| 69 |
+
inline void mul(aclrtStream stream, aclTensor* a, aclTensor* b, aclTensor* out) {
|
| 70 |
+
uint64_t ws = 0;
|
| 71 |
+
aclOpExecutor* exec = nullptr;
|
| 72 |
+
ACLNN_CHECK(aclnnMulGetWorkspaceSize(a, b, out, &ws, &exec));
|
| 73 |
+
void* wp = (ws > 0) ? _lca_pool().alloc(ws) : nullptr;
|
| 74 |
+
ACLNN_CHECK(aclnnMul(wp, ws, exec, stream));
|
| 75 |
+
}
|
| 76 |
+
|
| 77 |
+
// ---- Cast ----
|
| 78 |
+
inline void cast(aclrtStream stream, aclTensor* x, aclDataType dst_dtype, aclTensor* y) {
|
| 79 |
+
uint64_t ws = 0;
|
| 80 |
+
aclOpExecutor* exec = nullptr;
|
| 81 |
+
ACLNN_CHECK(aclnnCastGetWorkspaceSize(x, dst_dtype, y, &ws, &exec));
|
| 82 |
+
void* wp = (ws > 0) ? _lca_pool().alloc(ws) : nullptr;
|
| 83 |
+
ACLNN_CHECK(aclnnCast(wp, ws, exec, stream));
|
| 84 |
+
}
|
| 85 |
+
|
| 86 |
+
// ---- InplaceCopy: copy src (possibly non-contiguous via strides) into contiguous dst ----
|
| 87 |
+
inline void inplace_copy(aclrtStream stream, aclTensor* dst, aclTensor* src) {
|
| 88 |
+
uint64_t ws = 0;
|
| 89 |
+
aclOpExecutor* exec = nullptr;
|
| 90 |
+
ACLNN_CHECK(aclnnInplaceCopyGetWorkspaceSize(dst, src, &ws, &exec));
|
| 91 |
+
void* wp = (ws > 0) ? _lca_pool().alloc(ws) : nullptr;
|
| 92 |
+
ACLNN_CHECK(aclnnInplaceCopy(wp, ws, exec, stream));
|
| 93 |
+
}
|
| 94 |
+
|
| 95 |
+
// ---- Matmul: out = a @ b ----
|
| 96 |
+
// cube_math_type:
|
| 97 |
+
// 0 = KEEP_DTYPE, 1 = ALLOW_FP32_DOWN_PRECISION, 2 = USE_FP16, 3 = USE_HF32
|
| 98 |
+
inline void matmul(aclrtStream stream,
|
| 99 |
+
aclTensor* a, aclTensor* b, aclTensor* out,
|
| 100 |
+
int8_t cube_math_type = 1) {
|
| 101 |
+
uint64_t ws = 0;
|
| 102 |
+
aclOpExecutor* exec = nullptr;
|
| 103 |
+
ACLNN_CHECK(aclnnMatmulGetWorkspaceSize(a, b, out, cube_math_type, &ws, &exec));
|
| 104 |
+
void* wp = (ws > 0) ? _lca_pool().alloc(ws) : nullptr;
|
| 105 |
+
ACLNN_CHECK(aclnnMatmul(wp, ws, exec, stream));
|
| 106 |
+
}
|
| 107 |
+
|
| 108 |
+
// ---- Neg ----
|
| 109 |
+
inline void neg(aclrtStream stream, aclTensor* x, aclTensor* y) {
|
| 110 |
+
uint64_t ws = 0;
|
| 111 |
+
aclOpExecutor* exec = nullptr;
|
| 112 |
+
ACLNN_CHECK(aclnnNegGetWorkspaceSize(x, y, &ws, &exec));
|
| 113 |
+
void* wp = (ws > 0) ? _lca_pool().alloc(ws) : nullptr;
|
| 114 |
+
ACLNN_CHECK(aclnnNeg(wp, ws, exec, stream));
|
| 115 |
+
}
|
| 116 |
+
|
| 117 |
+
// ---- Addcmul: self = self + value * (tensor1 * tensor2) ----
|
| 118 |
+
inline void addcmul(aclrtStream stream, aclTensor* self_io, aclTensor* t1, aclTensor* t2, float value) {
|
| 119 |
+
aclScalar* v = aclCreateScalar(&value, ACL_FLOAT);
|
| 120 |
+
uint64_t ws = 0;
|
| 121 |
+
aclOpExecutor* exec = nullptr;
|
| 122 |
+
ACLNN_CHECK(aclnnAddcmulGetWorkspaceSize(self_io, t1, t2, v, self_io, &ws, &exec));
|
| 123 |
+
void* wp = (ws > 0) ? _lca_pool().alloc(ws) : nullptr;
|
| 124 |
+
ACLNN_CHECK(aclnnAddcmul(wp, ws, exec, stream));
|
| 125 |
+
aclDestroyScalar(v);
|
| 126 |
+
}
|
| 127 |
+
|
| 128 |
+
// ---- MoE Gating TopK Softmax ----
|
| 129 |
+
// x [N, E] → y [N, K] (top-K softmax probs), expert_idx [N, K] int32, row_idx [N, K] int32
|
| 130 |
+
inline void moe_gating_topk_softmax(aclrtStream stream,
|
| 131 |
+
aclTensor* x, int64_t k,
|
| 132 |
+
aclTensor* y_out, aclTensor* idx_out, aclTensor* row_idx_out) {
|
| 133 |
+
uint64_t ws = 0;
|
| 134 |
+
aclOpExecutor* exec = nullptr;
|
| 135 |
+
ACLNN_CHECK(aclnnMoeGatingTopKSoftmaxGetWorkspaceSize(x, nullptr, k, y_out, idx_out, row_idx_out, &ws, &exec));
|
| 136 |
+
void* wp = (ws > 0) ? _lca_pool().alloc(ws) : nullptr;
|
| 137 |
+
ACLNN_CHECK(aclnnMoeGatingTopKSoftmax(wp, ws, exec, stream));
|
| 138 |
+
}
|
| 139 |
+
|
| 140 |
+
// ---- MoE Init Routing V3 ----
|
| 141 |
+
// x [N, D], expert_idx [N, K] int32 → expanded_x [N*K, D], expanded_row_idx [N*K] int32,
|
| 142 |
+
// tokens_per_expert [E] int64
|
| 143 |
+
inline void moe_init_routing_v3(aclrtStream stream,
|
| 144 |
+
aclTensor* x, aclTensor* expert_idx,
|
| 145 |
+
int64_t n_experts, int64_t active_num,
|
| 146 |
+
aclTensor* expanded_x, aclTensor* expanded_row_idx,
|
| 147 |
+
aclTensor* tokens_per_expert)
|
| 148 |
+
{
|
| 149 |
+
int64_t range[2] = {0, n_experts};
|
| 150 |
+
aclIntArray* r = aclCreateIntArray(range, 2);
|
| 151 |
+
// scale_out_optional we dummy since quant_mode=-1 (no quant) still requires pass a placeholder?
|
| 152 |
+
// Per our POC test earlier: pass a real tensor for scale_out works.
|
| 153 |
+
// For simplicity here, we'll allocate a dummy [active_num] float tensor.
|
| 154 |
+
DeviceBuffer dummy(active_num * 4);
|
| 155 |
+
auto t_dummy = make_contig_tensor(dummy.get(), ACL_FLOAT, {active_num});
|
| 156 |
+
|
| 157 |
+
uint64_t ws = 0; aclOpExecutor* exec = nullptr;
|
| 158 |
+
// rowIdxType=1: expanded_row_idx[i] = sorted_position p for i-th original (n,k) flat index.
|
| 159 |
+
// This lets us use expanded_row_idx directly as the gather index (forward permutation).
|
| 160 |
+
ACLNN_CHECK(aclnnMoeInitRoutingV3GetWorkspaceSize(
|
| 161 |
+
x, expert_idx, nullptr, nullptr,
|
| 162 |
+
active_num, 0, n_experts, 0, 1, true, -1,
|
| 163 |
+
r, 1,
|
| 164 |
+
expanded_x, expanded_row_idx, tokens_per_expert, t_dummy.get(),
|
| 165 |
+
&ws, &exec));
|
| 166 |
+
void* wp = (ws > 0) ? _lca_pool().alloc(ws) : nullptr;
|
| 167 |
+
ACLNN_CHECK(aclnnMoeInitRoutingV3(wp, ws, exec, stream));
|
| 168 |
+
aclDestroyIntArray(r);
|
| 169 |
+
}
|
| 170 |
+
|
| 171 |
+
// ---- GroupedMatmulV4 (single-in single-out, M-axis split) ----
// x [T, K_in], w [E, K_in, N_out] contiguous row-major, group_list [E] int64
// → y [T, N_out], where group_list partitions the T rows across E experts.
// group_list_type: 0 = cumulative sums, 1 = per-group counts (V4 doc).
inline void grouped_matmul_v4(aclrtStream stream,
                              aclTensor* x, aclTensor* w, aclTensor* group_list, aclTensor* y,
                              int64_t group_list_type = 1)
{
    // The V4 API takes tensor lists; wrap each single tensor in a 1-element list.
    aclTensor* xa[] = {x}; aclTensorList* x_list = aclCreateTensorList(xa, 1);
    aclTensor* wa[] = {w}; aclTensorList* w_list = aclCreateTensorList(wa, 1);
    aclTensor* ya[] = {y}; aclTensorList* y_list = aclCreateTensorList(ya, 1);

    uint64_t ws = 0; aclOpExecutor* exec = nullptr;
    // NOTE(review): the literal args (3, 0, group_list_type, 0) are positional
    // mode selectors from the V4 header — confirm their names against the CANN
    // header before changing any of them.
    ACLNN_CHECK(aclnnGroupedMatmulV4GetWorkspaceSize(
        x_list, w_list,
        nullptr, nullptr, nullptr, nullptr, nullptr, nullptr,
        group_list,
        nullptr, nullptr, nullptr,
        3, 0, group_list_type, 0,
        y_list, nullptr, nullptr,
        &ws, &exec));
    void* wp = (ws > 0) ? _lca_pool().alloc(ws) : nullptr;
    ACLNN_CHECK(aclnnGroupedMatmulV4(wp, ws, exec, stream));
    // NOTE: TensorList takes ownership of the raw tensors. Destroying the list frees them,
    // which would cause double-free in the caller's AclTensorPtr. Leak the list (small cost,
    // but it is a per-call leak — unbounded over a long-running server).
    // A cleaner API would accept (ptr, shape, dtype) triples and build tensors internally.
    // TODO(M6): refactor for long-running use.
}
|
| 198 |
+
|
| 199 |
+
// ---- MoE Finalize Routing V2: out = x1 + weighted_sum of top-K outputs ----
|
| 200 |
+
// V2 has all inputs optional except expandedX/expandedRowIdx/out; pass nullptr for x1 to
|
| 201 |
+
// skip the residual add, or pass the residual to fuse it into this op.
|
| 202 |
+
inline void moe_finalize_routing(aclrtStream stream,
|
| 203 |
+
aclTensor* expanded_x,
|
| 204 |
+
aclTensor* x1_skip, // [N, D] added to output (nullable)
|
| 205 |
+
aclTensor* scales, // weights [N, K]
|
| 206 |
+
aclTensor* expanded_row_idx,
|
| 207 |
+
aclTensor* expert_idx, // [N, K] topk expert indices (nullable)
|
| 208 |
+
aclTensor* out)
|
| 209 |
+
{
|
| 210 |
+
uint64_t ws = 0; aclOpExecutor* exec = nullptr;
|
| 211 |
+
ACLNN_CHECK(aclnnMoeFinalizeRoutingV2GetWorkspaceSize(
|
| 212 |
+
expanded_x,
|
| 213 |
+
expanded_row_idx,
|
| 214 |
+
x1_skip, // x1Optional
|
| 215 |
+
nullptr, // x2Optional
|
| 216 |
+
nullptr, // biasOptional
|
| 217 |
+
scales, // scalesOptional
|
| 218 |
+
expert_idx, // expertIdxOptional (needed for correct routing)
|
| 219 |
+
0, // dropPadMode (0 = dropless, which matches our pipeline)
|
| 220 |
+
out,
|
| 221 |
+
&ws, &exec));
|
| 222 |
+
void* wp = (ws > 0) ? _lca_pool().alloc(ws) : nullptr;
|
| 223 |
+
ACLNN_CHECK(aclnnMoeFinalizeRoutingV2(wp, ws, exec, stream));
|
| 224 |
+
}
|
| 225 |
+
|
| 226 |
+
// ---- Div: self / other (broadcast supported) ----
|
| 227 |
+
inline void div_tensor(aclrtStream stream, aclTensor* self, aclTensor* other, aclTensor* out) {
|
| 228 |
+
uint64_t ws = 0; aclOpExecutor* exec = nullptr;
|
| 229 |
+
ACLNN_CHECK(aclnnDivGetWorkspaceSize(self, other, out, &ws, &exec));
|
| 230 |
+
void* wp = (ws > 0) ? _lca_pool().alloc(ws) : nullptr;
|
| 231 |
+
ACLNN_CHECK(aclnnDiv(wp, ws, exec, stream));
|
| 232 |
+
}
|
| 233 |
+
|
| 234 |
+
// ---- In-place scalar add: self += scalar ----
|
| 235 |
+
#include <aclnnop/aclnn_add.h>
|
| 236 |
+
#include <aclnnop/aclnn_argsort.h>
|
| 237 |
+
|
| 238 |
+
// ---- Argsort: indices that would sort self along dim (returns INT64) ----
|
| 239 |
+
inline void argsort(aclrtStream stream, aclTensor* self, int64_t dim, bool descending,
|
| 240 |
+
aclTensor* indices_out) {
|
| 241 |
+
uint64_t ws = 0; aclOpExecutor* exec = nullptr;
|
| 242 |
+
ACLNN_CHECK(aclnnArgsortGetWorkspaceSize(self, dim, descending, indices_out, &ws, &exec));
|
| 243 |
+
void* wp = (ws > 0) ? _lca_pool().alloc(ws) : nullptr;
|
| 244 |
+
ACLNN_CHECK(aclnnArgsort(wp, ws, exec, stream));
|
| 245 |
+
}
|
| 246 |
+
|
| 247 |
+
// In-place scalar add: self += value (value is narrowed to float, alpha fixed at 1).
inline void inplace_adds(aclrtStream stream, aclTensor* self, double value) {
    // aclnnInplaceAdds computes self += alpha * scalar; we always use alpha = 1.
    float addend = static_cast<float>(value);
    float alpha = 1.0f;
    aclScalar* s_addend = aclCreateScalar(&addend, ACL_FLOAT);
    aclScalar* s_alpha = aclCreateScalar(&alpha, ACL_FLOAT);
    aclOpExecutor* executor = nullptr;
    uint64_t workspace_bytes = 0;
    ACLNN_CHECK(aclnnInplaceAddsGetWorkspaceSize(self, s_addend, s_alpha,
                                                 &workspace_bytes, &executor));
    void* workspace = nullptr;
    if (workspace_bytes > 0) {
        workspace = _lca_pool().alloc(workspace_bytes);
    }
    ACLNN_CHECK(aclnnInplaceAdds(workspace, workspace_bytes, executor, stream));
    // Free the scalar descriptors once the op has been enqueued (same lifetime as before).
    aclDestroyScalar(s_addend);
    aclDestroyScalar(s_alpha);
}
|
| 259 |
+
|
| 260 |
+
// ---- ReduceSum over specified dims ----
|
| 261 |
+
inline void reduce_sum(aclrtStream stream, aclTensor* self, const std::vector<int64_t>& dims,
|
| 262 |
+
bool keep_dims, aclDataType out_dtype, aclTensor* out) {
|
| 263 |
+
aclIntArray* d = aclCreateIntArray(dims.data(), dims.size());
|
| 264 |
+
uint64_t ws = 0; aclOpExecutor* exec = nullptr;
|
| 265 |
+
ACLNN_CHECK(aclnnReduceSumGetWorkspaceSize(self, d, keep_dims, out_dtype, out, &ws, &exec));
|
| 266 |
+
void* wp = (ws > 0) ? _lca_pool().alloc(ws) : nullptr;
|
| 267 |
+
ACLNN_CHECK(aclnnReduceSum(wp, ws, exec, stream));
|
| 268 |
+
aclDestroyIntArray(d);
|
| 269 |
+
}
|
| 270 |
+
|
| 271 |
+
// ---- IndexSelect: out[j] = self[index[j], ...] ----
|
| 272 |
+
inline void index_select(aclrtStream stream, aclTensor* self, int64_t dim, aclTensor* index, aclTensor* out) {
|
| 273 |
+
uint64_t ws = 0;
|
| 274 |
+
aclOpExecutor* exec = nullptr;
|
| 275 |
+
ACLNN_CHECK(aclnnIndexSelectGetWorkspaceSize(self, dim, index, out, &ws, &exec));
|
| 276 |
+
void* wp = (ws > 0) ? _lca_pool().alloc(ws) : nullptr;
|
| 277 |
+
ACLNN_CHECK(aclnnIndexSelect(wp, ws, exec, stream));
|
| 278 |
+
}
|
| 279 |
+
|
| 280 |
+
// ---- FusedInferAttentionScore (simplified wrapper for prefill/decode without quant, BSH layout).
// Caller owns q/k/v/mask/out; k/v are single-tensor lists.
//
// Launches aclnnFusedInferAttentionScore on `stream` with all quantization, padding and
// paged-attention (block table) inputs disabled. pre/next token windows are set to
// INT32_MAX, i.e. no sliding-window limit — causality comes from atten_mask/sparse_mode.
//
// NOTE(review): actual_seq_lens / actual_seq_lens_kv are taken by value; these are tiny
// ({S}, {kv_len}) at every call site visible here, so the copy cost is negligible.
inline void fused_infer_attention_score(
    aclrtStream stream,
    aclTensor* q,          // [B, S, Hq*Dh] BF16
    aclTensor* k,          // [B, S, Hkv*Dh] BF16
    aclTensor* v,          // [B, S, Hkv*Dh] BF16
    aclTensor* atten_mask, // [1, 1, M, M] bool, sparse_mode=3 needs M=2048
    std::vector<int64_t> actual_seq_lens,
    std::vector<int64_t> actual_seq_lens_kv,
    int64_t num_heads, int64_t num_kv_heads,
    double scale, int64_t sparse_mode,
    aclTensor* out)        // [B, S, Hq*Dh]
{
    // k/v must be passed as tensor lists even for the single-batch case.
    aclTensor* k_arr[] = {k};
    aclTensor* v_arr[] = {v};
    aclTensorList* k_list = aclCreateTensorList(k_arr, 1);
    aclTensorList* v_list = aclCreateTensorList(v_arr, 1);
    aclIntArray* sq = aclCreateIntArray(actual_seq_lens.data(), (uint64_t)actual_seq_lens.size());
    aclIntArray* skv = aclCreateIntArray(actual_seq_lens_kv.data(), (uint64_t)actual_seq_lens_kv.size());

    uint64_t ws = 0;
    aclOpExecutor* exec = nullptr;
    // Argument order follows the aclnn signature exactly; every disabled feature is nullptr/0.
    ACLNN_CHECK(aclnnFusedInferAttentionScoreGetWorkspaceSize(
        q, k_list, v_list,
        nullptr, // pseShift
        atten_mask,
        sq, skv,
        nullptr, nullptr, nullptr, nullptr, nullptr, // dequant/quant scales
        nullptr, nullptr, // antiquant
        nullptr, nullptr, nullptr, // block_table, q_padding, kv_padding
        num_heads,
        scale,
        2147483647, 2147483647, // pre/next tokens (no limit)
        (char*)"BSH",
        num_kv_heads,
        sparse_mode,
        0, // inner_precise
        0, 0, // block_size, antiquant_mode
        false, // softmax_lse_flag
        out, nullptr,
        &ws, &exec));

    void* wp = (ws > 0) ? _lca_pool().alloc(ws) : nullptr;
    ACLNN_CHECK(aclnnFusedInferAttentionScore(wp, ws, exec, stream));
    // See note on grouped_matmul_v4 — intentionally leak lists to avoid double-free with caller RAII.
    (void)k_list; (void)v_list;
    aclDestroyIntArray(sq);
    aclDestroyIntArray(skv);
}
|
| 330 |
+
|
| 331 |
+
// ---- "Linear" helper: y = x @ W.T where W is stored as [out_features, in_features] (HF convention).
|
| 332 |
+
// Achieved by viewing W as [in_features, out_features] with stride [1, in_features] (elements).
|
| 333 |
+
// Returns y [N, out_features].
|
| 334 |
+
// Caller allocates y.
|
| 335 |
+
inline void linear_hf(aclrtStream stream,
|
| 336 |
+
aclTensor* x, // [N, in_features]
|
| 337 |
+
void* W_data, aclDataType dtype,
|
| 338 |
+
int64_t out_features, int64_t in_features,
|
| 339 |
+
aclTensor* y_out) // [N, out_features]
|
| 340 |
+
{
|
| 341 |
+
auto W_view = make_acl_tensor(W_data, dtype,
|
| 342 |
+
{in_features, out_features},
|
| 343 |
+
{1, in_features}); // strides: d0=1 elem, d1=in_features elems
|
| 344 |
+
matmul(stream, x, W_view.get(), y_out);
|
| 345 |
+
}
|
|
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// device_weights.h — load safetensors weights to device memory with proper TP shard.
|
| 2 |
+
//
|
| 3 |
+
// For M3 (attention only): loads attention + norm weights. MoE expert weights come in M4.
|
| 4 |
+
//
|
| 5 |
+
#pragma once
|
| 6 |
+
#include "acl_common.h"
|
| 7 |
+
#include "model_config.h"
|
| 8 |
+
#include "safetensors_loader.h"
|
| 9 |
+
|
| 10 |
+
#include <string>
|
| 11 |
+
#include <unordered_map>
|
| 12 |
+
#include <vector>
|
| 13 |
+
|
| 14 |
+
// Per-layer MoE weights on device (BF16).
// After loading: weights are in GMM-ready layout [E, K_in, N_out] row-major contiguous.
// For gate/up: K_in=D, N_out=I_per_rank
// For down: K_in=I, N_out=D
// (E = number of experts, D = hidden size, I_per_rank = per-rank intermediate size —
//  values come from ModelConfig; see load_moe in DeviceWeightsLoader.)
struct LayerMoEWeights {
    DeviceBuffer router;    // [E, D] BF16 replicated
    DeviceBuffer gate_exps; // [E, D, I_per_rank] (permuted from HF [E, I, D])
    DeviceBuffer up_exps;   // [E, D, I_per_rank]
    DeviceBuffer down_exps; // [E, I_per_rank, D] (permuted from HF [E, D, I])
};
|
| 24 |
+
|
| 25 |
+
// Per-layer attention weights on device (BF16 unless noted).
// All buffers are owned by this struct (DeviceBuffer RAII); consumed by attention_forward.
struct LayerAttnWeights {
    DeviceBuffer input_layernorm;          // [D] BF16
    DeviceBuffer post_attention_layernorm; // [D] BF16
    // Q/K/V/O projections. HF stores as [out, in] BF16.
    // For M3 we keep HF layout as-is; matmul wrappers handle the transpose via aclnnMm semantics.
    DeviceBuffer q_proj; // [Q_full, D] on rank, but physical stored as [Q_rank, D] (sliced by head)
    DeviceBuffer k_proj; // [KV, D] (replicated if tp_size > num_kv_heads)
    DeviceBuffer v_proj; // [KV, D]
    DeviceBuffer o_proj; // [D, Q_rank] (row-parallel on Q dim)
    DeviceBuffer q_norm; // [head_dim] BF16 (Qwen3 per-head norm)
    DeviceBuffer k_norm; // [head_dim] BF16
};
|
| 38 |
+
|
| 39 |
+
// Shared model weights (replicated across ranks).
// Loaded once per process by DeviceWeightsLoader::load_shared.
struct SharedWeights {
    DeviceBuffer embed_tokens; // [vocab, D]
    DeviceBuffer lm_head;      // [vocab, D]
    DeviceBuffer final_norm;   // [D]
};
|
| 45 |
+
|
| 46 |
+
// Loads safetensors weights onto the device, applying the correct TP shard per rank.
// Holds non-owning references to the SafetensorsLoader and ModelConfig — both must
// outlive this object. All load_* methods return false on failure (definitions are
// in the corresponding .cpp, not visible here).
class DeviceWeightsLoader {
public:
    DeviceWeightsLoader(SafetensorsLoader& st, const ModelConfig& cfg)
        : st_(st), cfg_(cfg) {}

    // Load shared (embed, norm, lm_head). Replicated on every rank.
    bool load_shared(SharedWeights& out);

    // Load ONE attention layer's weights with TP sharding.
    bool load_attention(int layer_idx, LayerAttnWeights& out);

    // Load ONE MoE layer's weights. Stacks 128 experts and permutes to GMM-ready layout.
    // stream: ACL stream for the permute op (aclnnInplaceCopy).
    bool load_moe(int layer_idx, aclrtStream stream, LayerMoEWeights& out);

    // Expose underlying safetensors for direct access (diagnostic use).
    SafetensorsLoader& st() { return st_; }

private:
    SafetensorsLoader& st_;
    const ModelConfig& cfg_;

    // Helper: load HF tensor (full shape) into device buffer (simple H2D).
    bool load_tensor_full_(const std::string& name, DeviceBuffer& buf);

    // Helper: load HF tensor and keep only [row_lo, row_hi) of first dim (TP shard by "out" dim).
    // HF format: tensor has shape [D0, D1, ...] stored row-major. We take rows [lo, hi) to form
    // a sharded tensor of shape [hi-lo, D1, ...].
    bool load_tensor_row_slice_(const std::string& name,
                                int64_t row_lo, int64_t row_hi,
                                DeviceBuffer& buf);

    // TP shard by "in" dim (second axis for 2D, etc.) — used for o_proj (row-parallel).
    bool load_tensor_col_slice_(const std::string& name,
                                int64_t col_lo, int64_t col_hi,
                                DeviceBuffer& buf);
};
|
|
@@ -0,0 +1,354 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// engine.h — single-layer forward functions for attention and MoE.
|
| 2 |
+
//
|
| 3 |
+
// Both functions operate on device tensors. The caller owns all buffers (input, output, weights,
|
| 4 |
+
// KV cache slots, scratch). They take RoPE cos/sin tables and act as pure forward kernels.
|
| 5 |
+
//
|
| 6 |
+
// Design goals:
|
| 7 |
+
// - Zero allocations per call (all scratch is passed in)
|
| 8 |
+
// - Same signature works for prefill (S>=1) and decode (S=1); caller picks sparse_mode.
|
| 9 |
+
// - Residual connection is NOT included (caller decides when to add residual).
|
| 10 |
+
#pragma once
|
| 11 |
+
#include "acl_common.h"
|
| 12 |
+
#include "aclnn_ops.h"
|
| 13 |
+
#include "device_weights.h"
|
| 14 |
+
#include "hccl_comm.h"
|
| 15 |
+
#include "model_config.h"
|
| 16 |
+
#include "rope.h"
|
| 17 |
+
|
| 18 |
+
#include <algorithm>
|
| 19 |
+
#include <cmath>
|
| 20 |
+
#include <cstring>
|
| 21 |
+
#include <tuple>
|
| 22 |
+
#include <vector>
|
| 23 |
+
|
| 24 |
+
// Bf16 conversion helpers used by fill_cos_sin.
// Converts float32 to bfloat16 (upper 16 bits of the IEEE-754 encoding) with
// round-to-nearest-even.
static inline uint16_t _engine_f2bf16(float x) {
    uint32_t u; std::memcpy(&u, &x, 4);
    // NaN guard: the rounding add below can carry a NaN mantissa across the exponent
    // field (e.g. bits 0x7FFFFFFF would round to 0x8000 == -0.0f), silently turning a
    // NaN into a finite value. Return a quiet NaN with the original sign instead.
    if ((u & 0x7FFFFFFFu) > 0x7F800000u) {
        return (uint16_t)((u >> 16) | 0x0040);
    }
    // Round-to-nearest-even: add 0x7FFF plus the LSB of the kept half, then truncate.
    return (uint16_t)((u + 0x7FFF + ((u >> 16) & 1)) >> 16);
}
|
| 29 |
+
|
| 30 |
+
// Fill cos/sin tables for positions [p0, p0+L) with HF half-half layout. Returns
|
| 31 |
+
// contiguous [L*Dh] BF16 in provided host vectors (caller uploads to device).
|
| 32 |
+
inline void fill_cos_sin_hf(std::vector<uint16_t>& cos_h, std::vector<uint16_t>& sin_h,
|
| 33 |
+
int64_t p0, int64_t L, int64_t Dh, float theta) {
|
| 34 |
+
cos_h.resize(L * Dh);
|
| 35 |
+
sin_h.resize(L * Dh);
|
| 36 |
+
int64_t half = Dh / 2;
|
| 37 |
+
for (int64_t s = 0; s < L; s++) {
|
| 38 |
+
for (int64_t d = 0; d < Dh; d++) {
|
| 39 |
+
int64_t pair = (d < half) ? d : (d - half);
|
| 40 |
+
float theta_pair = 1.0f / std::pow(theta, (2.0f * pair) / Dh);
|
| 41 |
+
float angle = (float)(p0 + s) * theta_pair;
|
| 42 |
+
cos_h[s * Dh + d] = _engine_f2bf16(std::cos(angle));
|
| 43 |
+
sin_h[s * Dh + d] = _engine_f2bf16(std::sin(angle));
|
| 44 |
+
}
|
| 45 |
+
}
|
| 46 |
+
}
|
| 47 |
+
|
| 48 |
+
// Precomputed RoPE cos/sin table: BF16 [max_seq, Dh]. One-time cost per runtime.
// Built by rope_cache_build; consumed by attention_forward, which slices rows
// [past_len, past_len+S) directly from the device buffers.
struct RopeCache {
    DeviceBuffer cos;     // [max_seq, Dh] BF16
    DeviceBuffer sin;     // [max_seq, Dh] BF16
    int64_t max_seq = 0;  // rows in the table; positions >= max_seq fall back to the manual path
    int64_t head_dim = 0; // Dh used when the table was built
    float theta = 0.0f;   // RoPE base frequency used when the table was built
};
|
| 56 |
+
|
| 57 |
+
// Build the full [max_seq, head_dim] cos/sin tables on the host and upload once.
// Always returns true (ACL_CHECK handles transfer failures).
inline bool rope_cache_build(RopeCache& rc, int64_t max_seq, int64_t head_dim, float theta) {
    std::vector<uint16_t> host_cos;
    std::vector<uint16_t> host_sin;
    fill_cos_sin_hf(host_cos, host_sin, /*p0=*/0, max_seq, head_dim, theta);
    rc.cos.alloc(max_seq * head_dim * 2);  // BF16 = 2 bytes per element
    rc.sin.alloc(max_seq * head_dim * 2);
    ACL_CHECK(aclrtMemcpy(rc.cos.get(), host_cos.size() * 2, host_cos.data(), host_cos.size() * 2, ACL_MEMCPY_HOST_TO_DEVICE));
    ACL_CHECK(aclrtMemcpy(rc.sin.get(), host_sin.size() * 2, host_sin.data(), host_sin.size() * 2, ACL_MEMCPY_HOST_TO_DEVICE));
    rc.max_seq = max_seq;
    rc.head_dim = head_dim;
    rc.theta = theta;
    return true;
}
|
| 67 |
+
|
| 68 |
+
// Attention forward for a single layer.
//
// x_in [S, D] (hidden state, pre input_layernorm)
// x_out [S, D] (attention output — NOT residual-added)
//
// K cache / V cache are contiguous [MAX_LEN, KV_DIM] BF16 buffers. This call writes new
// positions at [past_len, past_len+S) and then runs FIAS over [0, past_len+S).
//
// Scratch requirements:
// q_scratch: S * Q_DIM * 2 bytes
// k_scratch: S * KV_DIM * 2 bytes
// v_scratch: S * KV_DIM * 2 bytes
// xn_scratch: S * D * 2 bytes
// rstd_scratch: S * 4 bytes (RmsNorm rstd output)
// rope_scratch: S * Hq * Dh * 2 bytes
//
// mask: [1, 1, 2048, 2048] bool for prefill; ignored (pass nullptr) for decode.
//
// All kernels are enqueued asynchronously on `stream`; the only host sync is the one
// in the manual-RoPE fallback (to keep its temporary device buffers alive).
// NOTE(review): max_len is accepted but never checked here — presumably the caller
// guarantees past_len + S <= max_len; confirm at the call sites.
inline void attention_forward(
    aclrtStream stream,
    const ModelConfig& cfg,
    LayerAttnWeights& w,
    void* x_in, // [S, D] BF16
    int64_t S,
    int64_t past_len, // prior KV positions
    void* k_cache, void* v_cache, int64_t max_len,
    aclTensor* mask_tensor, // may be nullptr for decode
    void* q_scratch, void* k_scratch, void* v_scratch,
    void* xn_scratch, void* rstd_scratch, void* rope_scratch,
    void* attn_out_scratch, // S * Q_DIM * 2 bytes (FIAS output before o_proj)
    void* x_out, // [S, D] BF16
    HcclCtx* hccl_ctx = nullptr, // if tp_size > 1, AllReduce x_out after o_proj
    const RopeCache* rope_cache = nullptr, // if provided, use cached cos/sin table; avoids per-call H2D
    int64_t sparse_mode = -1 // -1=auto (3 for prefill, 0 for decode); explicit 0/3 overrides
) {
    const int64_t D = cfg.hidden_size;
    const int64_t Hq = cfg.n_heads_per_rank;
    const int64_t Hkv = cfg.n_kv_heads_per_rank;
    const int64_t Dh = cfg.head_dim;
    const int64_t Q_DIM = Hq * Dh;
    const int64_t KV_DIM = Hkv * Dh;
    const double scale = 1.0 / std::sqrt((double)Dh);
    const double eps = cfg.rms_norm_eps;
    const float theta = cfg.rope_theta;

    // 1. Input layernorm: xn = rmsnorm(x_in, input_layernorm_weight)
    auto t_x = make_contig_tensor(x_in, ACL_BF16, {S, D});
    auto t_xn = make_contig_tensor(xn_scratch, ACL_BF16, {S, D});
    auto t_lnw = make_contig_tensor(w.input_layernorm.get(), ACL_BF16, {D});
    auto t_rstd = make_contig_tensor(rstd_scratch, ACL_FLOAT, {S});
    rms_norm(stream, t_x.get(), t_lnw.get(), eps, t_xn.get(), t_rstd.get());

    // 2. Q/K/V projection (HF [out, in] weights; linear_hf applies the implicit transpose)
    auto t_q = make_contig_tensor(q_scratch, ACL_BF16, {S, Q_DIM});
    auto t_k = make_contig_tensor(k_scratch, ACL_BF16, {S, KV_DIM});
    auto t_v = make_contig_tensor(v_scratch, ACL_BF16, {S, KV_DIM});
    linear_hf(stream, t_xn.get(), w.q_proj.get(), ACL_BF16, Q_DIM, D, t_q.get());
    linear_hf(stream, t_xn.get(), w.k_proj.get(), ACL_BF16, KV_DIM, D, t_k.get());
    linear_hf(stream, t_xn.get(), w.v_proj.get(), ACL_BF16, KV_DIM, D, t_v.get());

    // 3. Per-head q_norm, k_norm — in-place RmsNorm over 4-D views of the same scratch.
    auto t_q_4d = make_contig_tensor(q_scratch, ACL_BF16, {1, S, Hq, Dh});
    auto t_k_4d = make_contig_tensor(k_scratch, ACL_BF16, {1, S, Hkv, Dh});
    auto t_qn_w = make_contig_tensor(w.q_norm.get(), ACL_BF16, {Dh});
    auto t_kn_w = make_contig_tensor(w.k_norm.get(), ACL_BF16, {Dh});
    // reuse rstd_scratch split or allocate? reuse xn_scratch's first S*Hq*4 bytes.
    // Simpler: require rstd_scratch to have max(S, S*max(Hq,Hkv)) * 4 bytes.
    // For single-rank attention tests we pass enough.
    // NOTE(review): the stated scratch requirement above ("S * 4 bytes") is the minimum for
    // step 1 only; these per-head norms need S*max(Hq,Hkv)*4 bytes — callers must size for that.
    auto t_rstd_q = make_contig_tensor(rstd_scratch, ACL_FLOAT, {1, S, Hq});
    auto t_rstd_k = make_contig_tensor(rstd_scratch, ACL_FLOAT, {1, S, Hkv});
    rms_norm(stream, t_q_4d.get(), t_qn_w.get(), eps, t_q_4d.get(), t_rstd_q.get());
    rms_norm(stream, t_k_4d.get(), t_kn_w.get(), eps, t_k_4d.get(), t_rstd_k.get());

    // 4. RoPE: positions [past_len, past_len + S). Fused aclnnApplyRotaryPosEmbV2 is 1 op
    // vs 8-op manual version — saves ~7 kernel launches/layer × 94 layers = 658/token.
    if (rope_cache && rope_cache->cos.get() && past_len + S <= rope_cache->max_seq) {
        // Slice the cached table rows for this position range (BF16 = 2 bytes).
        void* cos_ptr = (char*)rope_cache->cos.get() + past_len * Dh * 2;
        void* sin_ptr = (char*)rope_cache->sin.get() + past_len * Dh * 2;
        apply_rope_fused(stream, q_scratch, 1, S, Hq, Dh, k_scratch, Hkv, cos_ptr, sin_ptr);
    } else {
        // Fallback: compute cos/sin on the host for this call and upload.
        std::vector<uint16_t> cos_h, sin_h;
        fill_cos_sin_hf(cos_h, sin_h, past_len, S, Dh, theta);
        DeviceBuffer cos_dev(S * Dh * 2), sin_dev(S * Dh * 2);
        ACL_CHECK(aclrtMemcpy(cos_dev.get(), S*Dh*2, cos_h.data(), S*Dh*2, ACL_MEMCPY_HOST_TO_DEVICE));
        ACL_CHECK(aclrtMemcpy(sin_dev.get(), S*Dh*2, sin_h.data(), S*Dh*2, ACL_MEMCPY_HOST_TO_DEVICE));
        apply_rope_manual(stream, q_scratch, 1, S, Hq, Dh, k_scratch, Hkv,
                          cos_dev.get(), sin_dev.get(), rope_scratch);
        // Local DeviceBuffers would be freed on return while async kernels still read them.
        ACL_CHECK(aclrtSynchronizeStream(stream));
    }

    // 5. Append K, V to cache at [past_len, past_len + S)
    ACL_CHECK(aclrtMemcpyAsync((char*)k_cache + past_len * KV_DIM * 2, S * KV_DIM * 2,
                               k_scratch, S * KV_DIM * 2, ACL_MEMCPY_DEVICE_TO_DEVICE, stream));
    ACL_CHECK(aclrtMemcpyAsync((char*)v_cache + past_len * KV_DIM * 2, S * KV_DIM * 2,
                               v_scratch, S * KV_DIM * 2, ACL_MEMCPY_DEVICE_TO_DEVICE, stream));

    // 6. FIAS: q [1, S, Q_DIM], k/v [1, kv_len, KV_DIM] from cache
    int64_t kv_len = past_len + S;
    auto t_q_bsh = make_contig_tensor(q_scratch, ACL_BF16, {1, S, Q_DIM});
    auto t_k_bsh = make_contig_tensor(k_cache, ACL_BF16, {1, kv_len, KV_DIM});
    auto t_v_bsh = make_contig_tensor(v_cache, ACL_BF16, {1, kv_len, KV_DIM});
    // FIAS writes to a separate buffer (attn_out_scratch) — aliasing q→out is unsafe.
    auto t_attn_out_bsh = make_contig_tensor(attn_out_scratch, ACL_BF16, {1, S, Q_DIM});
    // sparse_mode selection:
    // 3 = left-top causal (prefill, q.S == kv.S with 2048 mask)
    // 0 = user mask (decode with cache, batch verify)
    // -1 (sentinel) = auto: 3 if mask given & past_len==0 & S>1 (prefill), else 0
    int64_t sparse = sparse_mode;
    if (sparse < 0) {
        sparse = (mask_tensor != nullptr && past_len == 0 && S > 1) ? 3 : 0;
    }
    fused_infer_attention_score(
        stream, t_q_bsh.get(), t_k_bsh.get(), t_v_bsh.get(),
        mask_tensor, {S}, {kv_len},
        Hq, Hkv, scale, sparse, t_attn_out_bsh.get());

    // 7. O projection: y = attn_out @ o_proj.T → [S, D]
    auto t_attn_2d = make_contig_tensor(attn_out_scratch, ACL_BF16, {S, Q_DIM});
    auto t_out = make_contig_tensor(x_out, ACL_BF16, {S, D});
    linear_hf(stream, t_attn_2d.get(), w.o_proj.get(), ACL_BF16, D, Q_DIM, t_out.get());

    // 8. TP AllReduce on x_out (row-parallel o_proj → SUM across ranks)
    if (hccl_ctx && hccl_ctx->tp_size > 1) {
        hccl_allreduce_bf16(*hccl_ctx, x_out, S * D, stream);
    }
}
|
| 194 |
+
|
| 195 |
+
// MoE forward for a single layer. Residual NOT applied here.
//
// x_in [S, D] (hidden state, pre post_attention_layernorm)
// x_out [S, D] (MoE output)
//
// Scratch:
// xn_scratch: S * D * 2
// rstd_scratch: S * 4
// logits_scratch: S * E * 2
// topk_w_scratch: S * K * 2
// topk_idx_scratch: S * K * 4
// row_idx_scratch: S * K * 4 (gating output unused)
// expanded_x_scratch: TOTAL * D * 2
// expanded_ri_scratch:TOTAL * 4
// tpe_scratch: E * 8
// fwd_dev: TOTAL * 8
// gate_out_scratch: TOTAL * I * 2
// up_out_scratch: TOTAL * I * 2
// down_out_scratch: TOTAL * D * 2
// packed_scratch: TOTAL * D * 2
// weighted_scratch: S * K * D * 2
//
// where TOTAL = S * K, I = cfg.i_per_rank, E = cfg.num_experts, K = cfg.num_experts_per_tok.
//
// IMPORTANT: post_attention_layernorm weight in `attn_w` (not in LayerMoEWeights).
//
// Pipeline: rmsnorm → router → topk gating → (normalize weights) → init routing →
// grouped gate/up GMM → SwiGLU → grouped down GMM → un-permute → weight & sum → AllReduce.
// Everything is enqueued on `stream`; the only host sync is the fallback normalize path.
inline void moe_forward(
    aclrtStream stream,
    const ModelConfig& cfg,
    LayerAttnWeights& attn_w, // for post_attention_layernorm
    LayerMoEWeights& w,
    void* x_in, int64_t S,
    void* xn_scratch, void* rstd_scratch,
    void* logits_scratch,
    void* topk_w_scratch, void* topk_idx_scratch, void* row_idx_scratch,
    void* expanded_x_scratch, void* expanded_ri_scratch, void* tpe_scratch,
    void* fwd_scratch,
    void* gate_out_scratch, void* up_out_scratch, void* down_out_scratch,
    void* packed_scratch, void* weighted_scratch,
    void* x_out,
    HcclCtx* hccl_ctx = nullptr, // if tp_size > 1, AllReduce after reduce_sum
    void* norm_sum_scratch = nullptr // S * 2 bytes — persistent buffer for topk_w normalize
) {
    const int64_t D = cfg.hidden_size;
    const int64_t I = cfg.i_per_rank;
    const int64_t E = cfg.num_experts;
    const int64_t K = cfg.num_experts_per_tok;
    const double eps = cfg.rms_norm_eps;
    const int64_t TOTAL = S * K;

    // 1. post_attention_layernorm
    auto t_x = make_contig_tensor(x_in, ACL_BF16, {S, D});
    auto t_xn = make_contig_tensor(xn_scratch, ACL_BF16, {S, D});
    auto t_lnw = make_contig_tensor(attn_w.post_attention_layernorm.get(), ACL_BF16, {D});
    auto t_rstd = make_contig_tensor(rstd_scratch, ACL_FLOAT, {S});
    rms_norm(stream, t_x.get(), t_lnw.get(), eps, t_xn.get(), t_rstd.get());

    // 2. Router linear: logits = xn @ router.T → [S, E]
    auto t_logits = make_contig_tensor(logits_scratch, ACL_BF16, {S, E});
    linear_hf(stream, t_xn.get(), w.router.get(), ACL_BF16, E, D, t_logits.get());

    // 3. TopK softmax (row_idx output is required by the op but unused downstream)
    auto t_topk_w = make_contig_tensor(topk_w_scratch, ACL_BF16, {S, K});
    auto t_topk_idx = make_contig_tensor(topk_idx_scratch, ACL_INT32, {S, K});
    auto t_row_idx = make_contig_tensor(row_idx_scratch, ACL_INT32, {S, K});
    moe_gating_topk_softmax(stream, t_logits.get(), K, t_topk_w.get(), t_topk_idx.get(), t_row_idx.get());

    // 4. Device-side normalize topk weights (Qwen3 norm_topk_prob=true).
    // sum = reduce_sum(topk_w, dim=-1, keepdim=true) # [S, 1] F32 in rstd_scratch
    // sum += 1e-20
    // sum_bf16 = cast(sum, BF16) # [S, 1] in norm_sum_scratch (caller-owned)
    // topk_w /= sum_bf16 # broadcast divide
    // No per-layer syncs — all scratch buffers persist across layers.
    if (norm_sum_scratch) {
        auto t_sum = make_contig_tensor(rstd_scratch, ACL_FLOAT, {S, 1});
        auto t_sum_bf16 = make_contig_tensor(norm_sum_scratch, ACL_BF16, {S, 1});
        reduce_sum(stream, t_topk_w.get(), {-1}, /*keep_dims=*/true, ACL_FLOAT, t_sum.get());
        inplace_adds(stream, t_sum.get(), 1e-20);
        cast(stream, t_sum.get(), ACL_BF16, t_sum_bf16.get());
        div_tensor(stream, t_topk_w.get(), t_sum_bf16.get(), t_topk_w.get());
    } else {
        // Fallback: host-side normalize (for callers that didn't provide scratch).
        // This path syncs the stream and round-trips the weights through the host.
        ACL_CHECK(aclrtSynchronizeStream(stream));
        std::vector<uint16_t> h_tw(S * K);
        ACL_CHECK(aclrtMemcpy(h_tw.data(), S*K*2, topk_w_scratch, S*K*2, ACL_MEMCPY_DEVICE_TO_HOST));
        for (int s = 0; s < S; s++) {
            float sum = 0;
            for (int k = 0; k < K; k++) {
                // Widen BF16 → F32 by placing the 16 stored bits in the high half.
                uint32_t u = (uint32_t)h_tw[s*K + k] << 16;
                float v; std::memcpy(&v, &u, 4);
                sum += v;
            }
            sum += 1e-20f;
            for (int k = 0; k < K; k++) {
                uint32_t u = (uint32_t)h_tw[s*K + k] << 16;
                float v; std::memcpy(&v, &u, 4);
                v /= sum;
                std::memcpy(&u, &v, 4);
                // Round-to-nearest-even back to BF16.
                h_tw[s*K + k] = (uint16_t)((u + 0x7FFF + ((u >> 16) & 1)) >> 16);
            }
        }
        ACL_CHECK(aclrtMemcpy(topk_w_scratch, S*K*2, h_tw.data(), S*K*2, ACL_MEMCPY_HOST_TO_DEVICE));
    }

    // 5. MoE init routing: permute tokens into expert-grouped order
    auto t_ex_x = make_contig_tensor(expanded_x_scratch, ACL_BF16, {TOTAL, D});
    auto t_ex_ri = make_contig_tensor(expanded_ri_scratch, ACL_INT32, {TOTAL});
    auto t_tpe = make_contig_tensor(tpe_scratch, ACL_INT64, {E});
    moe_init_routing_v3(stream, t_xn.get(), t_topk_idx.get(),
                        E, TOTAL, t_ex_x.get(), t_ex_ri.get(), t_tpe.get());

    // 6. GMM gate + up (per-expert grouped matmuls, group sizes from t_tpe)
    auto t_gate_out = make_contig_tensor(gate_out_scratch, ACL_BF16, {TOTAL, I});
    auto t_up_out = make_contig_tensor(up_out_scratch, ACL_BF16, {TOTAL, I});
    auto t_w_gate = make_contig_tensor(w.gate_exps.get(), ACL_BF16, {E, D, I});
    auto t_w_up = make_contig_tensor(w.up_exps.get(), ACL_BF16, {E, D, I});
    grouped_matmul_v4(stream, t_ex_x.get(), t_w_gate.get(), t_tpe.get(), t_gate_out.get(), 1);
    grouped_matmul_v4(stream, t_ex_x.get(), t_w_up.get(), t_tpe.get(), t_up_out.get(), 1);

    // 7. SwiGLU: gate_out = silu(gate_out) * up_out (in place)
    silu(stream, t_gate_out.get(), t_gate_out.get());
    mul(stream, t_gate_out.get(), t_up_out.get(), t_gate_out.get());

    // 8. GMM down
    auto t_down_out = make_contig_tensor(down_out_scratch, ACL_BF16, {TOTAL, D});
    auto t_w_down = make_contig_tensor(w.down_exps.get(), ACL_BF16, {E, I, D});
    grouped_matmul_v4(stream, t_gate_out.get(), t_w_down.get(), t_tpe.get(), t_down_out.get(), 1);

    // 9. Device-side finalize: build forward perm via two consecutive argsorts on topk_idx.
    // No host sync — safe for graph capture.
    // inv_fwd = argsort(topk_idx.flat) // each (n,k) → sorted position (primary key: expert_id)
    // fwd = argsort(inv_fwd) // inverse perm — what IndexSelect needs
    // Stability: aclnnArgsort preserves input order for equal keys; flat index = n*K + k orders
    // ties by n-then-k, matching our previous manual sort convention.
    //
    // Scratch for inv_fwd: reuse first TOTAL*8 bytes of weighted_scratch (gets overwritten
    // by the subsequent mul op, so aliasing is safe).
    {
        auto t_topk_idx_flat = make_contig_tensor(topk_idx_scratch, ACL_INT32, {TOTAL});
        auto t_inv_fwd = make_contig_tensor(weighted_scratch, ACL_INT64, {TOTAL});
        auto t_fwd_64 = make_contig_tensor(fwd_scratch, ACL_INT64, {TOTAL});
        argsort(stream, t_topk_idx_flat.get(), /*dim=*/0, /*descending=*/false, t_inv_fwd.get());
        argsort(stream, t_inv_fwd.get(), /*dim=*/0, /*descending=*/false, t_fwd_64.get());
    }
    // Un-permute expert outputs back to (token, k) order.
    auto t_fwd = make_contig_tensor(fwd_scratch, ACL_INT64, {TOTAL});
    auto t_packed = make_contig_tensor(packed_scratch, ACL_BF16, {TOTAL, D});
    index_select(stream, t_down_out.get(), 0, t_fwd.get(), t_packed.get());

    // Weight each expert output by its (normalized) gate weight, then sum over K.
    auto t_packed_3d = make_contig_tensor(packed_scratch, ACL_BF16, {S, K, D});
    auto t_topk_w_3d = make_contig_tensor(topk_w_scratch, ACL_BF16, {S, K, 1});
    auto t_weighted = make_contig_tensor(weighted_scratch, ACL_BF16, {S, K, D});
    mul(stream, t_packed_3d.get(), t_topk_w_3d.get(), t_weighted.get());

    auto t_out = make_contig_tensor(x_out, ACL_BF16, {S, D});
    reduce_sum(stream, t_weighted.get(), {1}, false, ACL_BF16, t_out.get());

    // TP AllReduce on MoE output (column-parallel experts → SUM partial intermediate outputs)
    if (hccl_ctx && hccl_ctx->tp_size > 1) {
        hccl_allreduce_bf16(*hccl_ctx, x_out, S * D, stream);
    }
}
|
|
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// hccl_comm.h — minimal HCCL wrapper for TP=N AllReduce.
|
| 2 |
+
//
|
| 3 |
+
// Multi-process mode (each rank is a separate process, device 0 each):
|
| 4 |
+
// - Rank 0 calls HcclGetRootInfo, writes to /tmp/hccl_root_info.bin
|
| 5 |
+
// - Rank 1..N-1 wait for that file, read it
|
| 6 |
+
// - All ranks call HcclCommInitRootInfo → shared HcclComm
|
| 7 |
+
// - allreduce() does in-place HcclAllReduce with SUM op
|
| 8 |
+
//
|
| 9 |
+
// Launcher sets HCCL_WHITELIST_DISABLE=1, ASCEND_RT_VISIBLE_DEVICES=<rank>, etc.
|
| 10 |
+
#pragma once
|
| 11 |
+
#include <hccl/hccl.h>
|
| 12 |
+
#include <hccl/hccl_types.h>
|
| 13 |
+
#include <acl/acl.h>
|
| 14 |
+
|
| 15 |
+
#include <chrono>
|
| 16 |
+
#include <cstdio>
|
| 17 |
+
#include <cstring>
|
| 18 |
+
#include <string>
|
| 19 |
+
#include <thread>
|
| 20 |
+
|
| 21 |
+
#define HCCL_ROOT_INFO_PATH "/tmp/hccl_root_info.bin"
|
| 22 |
+
|
| 23 |
+
// Per-process HCCL communicator state for tensor-parallel collectives.
// Filled in by hccl_init(); torn down by hccl_shutdown().
struct HcclCtx {
    HcclComm comm = nullptr;   // communicator handle; nullptr until init (and again after shutdown)
    int tp_size = 1;           // tensor-parallel world size (1 ⇒ all collectives are no-ops)
    int tp_rank = 0;           // this process's rank in [0, tp_size)
    bool initialized = false;  // true only after a successful hccl_init(); gated by every wrapper
};
|
| 29 |
+
|
| 30 |
+
// Create the TP communicator shared by all ranks.
//
// Rendezvous protocol (multi-process mode, one rank per process):
//   rank 0:      HcclGetRootInfo → write blob to HCCL_ROOT_INFO_PATH
//   rank 1..N-1: poll that file every 100 ms (60 s timeout) until a
//                full-size blob can be read
//   all ranks:   HcclCommInitRootInfo with the shared blob
//
// tp_size <= 1 short-circuits to an initialized no-comm context.
// Returns true on success; on failure logs to stderr and leaves
// ctx.initialized == false.
//
// NOTE(review): a stale root-info file left by a previous run could be read
// by non-zero ranks before rank 0 overwrites it — the launcher is expected
// to clean /tmp between runs; verify tp_launch.sh does so.
inline bool hccl_init(HcclCtx& ctx, int tp_size, int tp_rank) {
    if (tp_size <= 1) { ctx.tp_size = 1; ctx.tp_rank = 0; ctx.initialized = true; return true; }
    ctx.tp_size = tp_size;
    ctx.tp_rank = tp_rank;

    HcclRootInfo rootInfo;
    std::memset(&rootInfo, 0, sizeof(rootInfo));

    if (tp_rank == 0) {
        if (HcclGetRootInfo(&rootInfo) != HCCL_SUCCESS) {
            fprintf(stderr, "[HCCL] HcclGetRootInfo failed\n"); return false;
        }
        FILE* f = fopen(HCCL_ROOT_INFO_PATH, "wb");
        if (!f) { fprintf(stderr, "[HCCL] cannot write %s\n", HCCL_ROOT_INFO_PATH); return false; }
        // BUGFIX: the fwrite result was previously ignored — a short write
        // (e.g. /tmp full) left a truncated blob, and every other rank would
        // spin on it until the 60 s timeout instead of failing fast here.
        const size_t wr = fwrite(&rootInfo, sizeof(rootInfo), 1, f);
        fclose(f);
        if (wr != 1) {
            fprintf(stderr, "[HCCL] short write to %s\n", HCCL_ROOT_INFO_PATH);
            return false;
        }
    } else {
        bool found = false;
        for (int r = 0; r < 600; r++) { // 600 × 100 ms = 60 s timeout
            FILE* f = fopen(HCCL_ROOT_INFO_PATH, "rb");
            if (f) {
                size_t rd = fread(&rootInfo, 1, sizeof(rootInfo), f);
                fclose(f);
                // Partial reads (rank 0 mid-write) simply retry on the next poll.
                if (rd == sizeof(rootInfo)) { found = true; break; }
            }
            std::this_thread::sleep_for(std::chrono::milliseconds(100));
        }
        if (!found) { fprintf(stderr, "[HCCL] rank %d timeout waiting for root info\n", tp_rank); return false; }
    }

    HcclResult r = HcclCommInitRootInfo((uint32_t)tp_size, &rootInfo, (uint32_t)tp_rank, &ctx.comm);
    if (r != HCCL_SUCCESS) {
        fprintf(stderr, "[HCCL] HcclCommInitRootInfo failed: %d (rank=%d)\n", (int)r, tp_rank);
        return false;
    }
    ctx.initialized = true;
    fprintf(stderr, "[HCCL] rank %d/%d comm OK\n", tp_rank, tp_size);
    return true;
}
|
| 69 |
+
|
| 70 |
+
// In-place AllReduce SUM on BF16 tensor. dtype = HCCL_DATA_TYPE_BFP16.
|
| 71 |
+
inline bool hccl_allreduce_bf16(const HcclCtx& ctx, void* data, int64_t count, aclrtStream stream) {
|
| 72 |
+
if (!ctx.initialized) return false;
|
| 73 |
+
if (ctx.tp_size <= 1) return true; // no-op
|
| 74 |
+
|
| 75 |
+
HcclResult r = HcclAllReduce(data, data, (uint64_t)count,
|
| 76 |
+
HCCL_DATA_TYPE_BFP16, HCCL_REDUCE_SUM,
|
| 77 |
+
ctx.comm, stream);
|
| 78 |
+
if (r != HCCL_SUCCESS) {
|
| 79 |
+
fprintf(stderr, "[HCCL] AllReduce failed: %d\n", (int)r);
|
| 80 |
+
return false;
|
| 81 |
+
}
|
| 82 |
+
return true;
|
| 83 |
+
}
|
| 84 |
+
|
| 85 |
+
// Broadcast buffer from root (rank 0) to all ranks. Used to share prompt tokens across ranks.
|
| 86 |
+
// `data_dev` must be device memory. dtype generic (e.g., HCCL_DATA_TYPE_INT32).
|
| 87 |
+
inline bool hccl_broadcast(const HcclCtx& ctx, void* data_dev, int64_t count,
|
| 88 |
+
HcclDataType dtype, uint32_t root, aclrtStream stream) {
|
| 89 |
+
if (!ctx.initialized) return false;
|
| 90 |
+
if (ctx.tp_size <= 1) return true;
|
| 91 |
+
|
| 92 |
+
HcclResult r = HcclBroadcast(data_dev, (uint64_t)count, dtype, root, ctx.comm, stream);
|
| 93 |
+
if (r != HCCL_SUCCESS) {
|
| 94 |
+
fprintf(stderr, "[HCCL] Broadcast failed: %d\n", (int)r);
|
| 95 |
+
return false;
|
| 96 |
+
}
|
| 97 |
+
return true;
|
| 98 |
+
}
|
| 99 |
+
|
| 100 |
+
// Destroy the communicator (if any) and mark the context uninitialized.
// Safe to call multiple times or on a never-initialized context.
inline void hccl_shutdown(HcclCtx& ctx) {
    ctx.initialized = false;
    if (ctx.comm != nullptr) {
        HcclCommDestroy(ctx.comm);
        ctx.comm = nullptr;
    }
}
|
|
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// model_config.h — Qwen3 hparams loaded from HF config.json, plus TP-derived per-rank sizes.
|
| 2 |
+
#pragma once
|
| 3 |
+
#include <cstdint>
|
| 4 |
+
#include <string>
|
| 5 |
+
|
| 6 |
+
// Qwen3 hyperparameters parsed from the HF config.json, plus per-rank sizes
// derived from the tensor-parallel layout (filled by compute_derived()).
struct ModelConfig {
    // ---- Raw hparams from config.json ----
    int64_t vocab_size = 0;
    int64_t hidden_size = 0;            // D
    int64_t intermediate_size = 0;      // dense FFN (not used for MoE layers; kept for completeness)
    int64_t moe_intermediate_size = 0;  // I per expert
    int64_t num_hidden_layers = 0;      // = 94 for Qwen3-235B
    int64_t num_attention_heads = 0;    // = 64
    int64_t num_key_value_heads = 0;    // = 4 (GQA)
    int64_t head_dim = 0;               // = 128
    int64_t num_experts = 0;            // = 128
    int64_t num_experts_per_tok = 0;    // top_k = 8
    int64_t max_position_embeddings = 0;
    float rope_theta = 0.0f;            // RoPE base frequency
    float rms_norm_eps = 1e-6f;         // RMSNorm epsilon
    bool norm_topk_prob = true;         // renormalize the router's top-k weights
    bool tie_word_embeddings = false;   // lm_head shares embedding weights when true
    int64_t bos_token_id = 0;
    int64_t eos_token_id = 0;

    // ---- TP configuration ----
    int tp_size = 1;
    int tp_rank = 0;

    // ---- Derived per-rank sizes ----
    // Attention Q: split along num_heads (head-parallel)
    //   n_heads_per_rank = num_attention_heads / tp_size
    //   q_dim_per_rank   = n_heads_per_rank * head_dim
    int64_t n_heads_per_rank = 0;
    int64_t q_dim_per_rank = 0;

    // Attention KV: GQA with num_kv_heads < tp_size needs special handling.
    // For Qwen3-235B: num_kv_heads = 4, tp_size = 16 → each KV head is replicated 4× across ranks.
    // Simple scheme: each rank computes ALL kv heads (small, 4 × 128 = 512 features)
    // then slices attention output for its own q heads.
    // Alternative: split KV heads if tp_size <= num_kv_heads.
    int64_t n_kv_heads_per_rank = 0;
    int64_t kv_dim_per_rank = 0;

    // MoE: intermediate dim split. Each rank holds 1/tp_size of experts' intermediate_size.
    //   i_per_rank = moe_intermediate_size / tp_size
    int64_t i_per_rank = 0;

    // Parse raw hparams from the config.json at `path`; false on IO/parse failure.
    bool load_from_json(const std::string& path);
    // Fill the derived per-rank fields above for the given TP layout.
    void compute_derived(int tp_size, int tp_rank);
    // Human-readable summary for logging.
    std::string describe() const;
};
|
|
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// rope.h — Manual HF-style RoPE using basic aclnn ops.
|
| 2 |
+
//
|
| 3 |
+
// Formula: q_out = q * cos + rotate_half(q) * sin
|
| 4 |
+
// where rotate_half(q) = concat(-q[..., d/2:], q[..., :d/2], dim=-1)
|
| 5 |
+
//
|
| 6 |
+
// Tensor layout: q/k [B, S, N, Dh] BF16, cos/sin [1, S, Dh] BF16
|
| 7 |
+
// (cos/sin are broadcast across B and N dims)
|
| 8 |
+
//
|
| 9 |
+
#pragma once
|
| 10 |
+
#include "acl_common.h"
|
| 11 |
+
#include "aclnn_ops.h"
|
| 12 |
+
#include <aclnnop/aclnn_apply_rotary_pos_emb_v2.h>
|
| 13 |
+
|
| 14 |
+
// Fused RoPE via aclnnApplyRotaryPosEmbV2 — replaces the 8-op manual version with a single
|
| 15 |
+
// op, saving ~7 launches per layer × 94 layers = ~658 launches/token. Validated on 910 initial:
|
| 16 |
+
// layout=1 + rotaryMode="half" matches HF rotate_half semantics (rel=1.24e-3 vs manual).
|
| 17 |
+
//
|
| 18 |
+
// q_data: [B, S, Nq, Dh] BF16 (modified in place)
|
| 19 |
+
// k_data: [B, S, Nk, Dh] BF16 (modified in place)
|
| 20 |
+
// cos_data / sin_data: [1, S, 1, Dh] BF16 (single contiguous buffer slice from RopeCache)
|
| 21 |
+
inline void apply_rope_fused(aclrtStream stream,
|
| 22 |
+
void* q_data, int64_t B, int64_t S, int64_t Nq, int64_t Dh,
|
| 23 |
+
void* k_data, int64_t Nk,
|
| 24 |
+
void* cos_data, void* sin_data) {
|
| 25 |
+
const aclDataType dt = ACL_BF16;
|
| 26 |
+
auto t_q = make_contig_tensor(q_data, dt, {B, S, Nq, Dh});
|
| 27 |
+
auto t_k = make_contig_tensor(k_data, dt, {B, S, Nk, Dh});
|
| 28 |
+
auto t_cos = make_contig_tensor(cos_data, dt, {1, S, 1, Dh});
|
| 29 |
+
auto t_sin = make_contig_tensor(sin_data, dt, {1, S, 1, Dh});
|
| 30 |
+
uint64_t ws = 0; aclOpExecutor* exec = nullptr;
|
| 31 |
+
char mode[] = "half";
|
| 32 |
+
ACLNN_CHECK(aclnnApplyRotaryPosEmbV2GetWorkspaceSize(
|
| 33 |
+
t_q.get(), t_k.get(), t_cos.get(), t_sin.get(),
|
| 34 |
+
/*layout=*/1, mode, &ws, &exec));
|
| 35 |
+
void* wp = (ws > 0) ? _lca_pool().alloc(ws) : nullptr;
|
| 36 |
+
ACLNN_CHECK(aclnnApplyRotaryPosEmbV2(wp, ws, exec, stream));
|
| 37 |
+
}
|
| 38 |
+
|
| 39 |
+
// Apply HF-style RoPE in-place to q and k using basic aclnn ops:
//   x_out = x * cos + rotate_half(x) * sin
//   rotate_half(x) = concat(-x[..., d/2:], x[..., :d/2], dim=-1)
//
// q_data: pointer to [B, S, Nq, Dh] BF16 (modified in place)
// k_data: pointer to [B, S, Nk, Dh] BF16 (modified in place)
// cos_data, sin_data: [1, S, Dh] BF16 (broadcast across B and N via zero strides)
// scratch_data: pointer to contiguous [B, S, max(Nq,Nk), Dh] BF16 scratch buffer for rotate_half
//
// NOTE: the op ordering below matters — rotate_half must be materialized into
// scratch BEFORE x is overwritten by the in-place multiply; do not reorder.
inline void apply_rope_manual(aclrtStream stream,
                              void* q_data, int64_t B, int64_t S, int64_t Nq, int64_t Dh,
                              void* k_data, int64_t Nk,
                              void* cos_data, void* sin_data,
                              void* scratch_data) {
    const aclDataType dt = ACL_BF16;
    const size_t elem = 2;           // sizeof a BF16 element, used for half-dim pointer offsets
    const int64_t halfDh = Dh / 2;   // rotate_half splits the last dim at Dh/2

    // Processes one of q / k; closure captures B, S, Dh, cos/sin, scratch.
    auto process = [&](void* x_data, int64_t N) {
        // Strides in elements (row-major [B, S, N, Dh]):
        // stride = [S*N*Dh, N*Dh, Dh, 1]
        const std::vector<int64_t> full_shape = {B, S, N, Dh};
        const std::vector<int64_t> full_stride = {S*N*Dh, N*Dh, Dh, 1};
        const std::vector<int64_t> half_shape = {B, S, N, halfDh};
        // Half-width views keep the FULL row stride so element (b,s,n,i) of the
        // view maps onto the correct half of the underlying Dh-wide row.
        const std::vector<int64_t> half_stride = full_stride; // same leading 3 strides

        // View x as full
        auto t_x = make_acl_tensor(x_data, dt, full_shape, full_stride);

        // View of x left half and right half (shifted pointers, same layout, last dim half)
        auto t_x_left = make_acl_tensor(x_data, dt, half_shape, half_stride);
        auto t_x_right = make_acl_tensor((char*)x_data + halfDh*elem, dt, half_shape, half_stride);

        // rotate_half buffer view (contiguous [B, S, N, Dh]) in caller-provided scratch
        const std::vector<int64_t> rh_stride = {S*N*Dh, N*Dh, Dh, 1};
        auto t_rh = make_acl_tensor(scratch_data, dt, full_shape, rh_stride);
        auto t_rh_left = make_acl_tensor(scratch_data, dt, half_shape, rh_stride);
        auto t_rh_right = make_acl_tensor((char*)scratch_data + halfDh*elem, dt, half_shape, rh_stride);

        // rh[..., :Dh/2] = -x[..., Dh/2:]
        neg(stream, t_x_right.get(), t_rh_left.get());
        // rh[..., Dh/2:] = x[..., :Dh/2]
        inplace_copy(stream, t_rh_right.get(), t_x_left.get());

        // cos/sin views broadcastable to [B, S, N, Dh]
        // Original storage: [1, S, Dh]. For broadcast, use shape [1, S, 1, Dh] with strides [0, Dh, 0, 1].
        auto t_cos = make_acl_tensor(cos_data, dt, {1, S, 1, Dh}, {0, Dh, 0, 1});
        auto t_sin = make_acl_tensor(sin_data, dt, {1, S, 1, Dh}, {0, Dh, 0, 1});

        // q_rot = q * cos + rh * sin. No extra buffer needed: x is safe to
        // overwrite now because rh already holds the rotated copy.
        // aclnnMul with x as both in and out is inplace.
        mul(stream, t_x.get(), t_cos.get(), t_x.get());        // x = x * cos
        addcmul(stream, t_x.get(), t_rh.get(), t_sin.get(), 1); // x += 1 * (rh * sin)
    };

    process(q_data, Nq);
    process(k_data, Nk);  // scratch is reused; safe because the stream serializes ops
}
|
|
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// runner.h — multi-layer transformer Runner for Qwen3-235B-A22B.
|
| 2 |
+
//
|
| 3 |
+
// Owns: shared weights, per-layer attention + MoE weights, KV cache, scratch buffers.
|
| 4 |
+
// Provides: prefill(tokens) and decode(new_token) methods returning logits [vocab] on device.
|
| 5 |
+
//
|
| 6 |
+
// Memory budget at TP=1 for testing a SUBSET of layers (num_layers_to_load <= 94). Full 94-layer
|
| 7 |
+
// inference requires TP=16 where per-rank MoE fits ~28GB.
|
| 8 |
+
#pragma once
|
| 9 |
+
#include "acl_common.h"
|
| 10 |
+
#include "acl_runtime.h"
|
| 11 |
+
#include "aclnn_ops.h"
|
| 12 |
+
#include "device_weights.h"
|
| 13 |
+
#include "engine.h"
|
| 14 |
+
#include "hccl_comm.h"
|
| 15 |
+
#include "model_config.h"
|
| 16 |
+
#include "safetensors_loader.h"
|
| 17 |
+
|
| 18 |
+
#include <vector>
|
| 19 |
+
|
| 20 |
+
// Multi-layer transformer Runner for Qwen3-235B-A22B.
// Owns: shared weights, per-layer attention + MoE weights, per-layer KV cache,
// and all decode/prefill scratch buffers. Not copyable (owns device memory).
class Runner {
public:
    Runner() = default;
    ~Runner() = default;
    Runner(const Runner&) = delete;
    Runner& operator=(const Runner&) = delete;

    // Initialize runtime, open safetensors, load shared weights. tp_size/tp_rank configure
    // MoE + attention sharding. num_layers_to_load is how many transformer blocks to load (1..94).
    // max_seq is the maximum sequence length (sizes the KV cache allocation).
    bool init(const std::string& model_dir, int tp_size, int tp_rank,
              int num_layers_to_load, int64_t max_seq, int device_id = 0);

    // Prefill: ingest S>=1 tokens, produces logits [vocab] for the LAST position.
    // Populates the KV cache starting at position 0.
    bool prefill(const int32_t* tokens, int64_t S, DeviceBuffer& logits_out);

    // Decode: take 1 new token, produce logits [vocab] from the new position.
    bool decode(int32_t token, DeviceBuffer& logits_out);

    // Batched decode: take S tokens as "candidate verify batch" at positions
    // [past_len..past_len+S), produce logits [S, vocab]. Uses a causal-with-past
    // mask (token i sees past + tokens[0..i]). Foundation for speculative decoding / PLD.
    // tokens: [S] int32, S: 1..16
    // all_logits_out: will hold S * vocab_size * 2 bytes BF16, row-major [S, V]
    // Updates past_len by +S on success.
    bool decode_batch(const int32_t* tokens, int64_t S, DeviceBuffer& all_logits_out);

    // Warmup: run N dummy decode() calls (resetting cache) to pre-compile aclnn executors,
    // warm HCCL collective buffers, and stabilize NPU thermals. Improves first-N-token latency
    // by ~1 s (especially noticeable on short generations or REPL cold start).
    // Call after init(); safe to call multiple times. Does NOT affect past_len.
    void warmup(int iterations = 3);

    // ---- Accessors ----
    const ModelConfig& cfg() const { return cfg_; }
    aclrtStream stream() { return rt_.stream(); }
    int64_t past_len() const { return past_len_; }
    void reset_cache() { past_len_ = 0; }
    // Rewind past_len by n. Used by speculative decoding to discard rejected draft tokens'
    // KV cache entries (they'll be overwritten by subsequent writes).
    // Out-of-range n (<=0 or > past_len) is silently ignored.
    void rewind_cache(int64_t n) { if (n > 0 && n <= past_len_) past_len_ -= n; }
    HcclCtx& hccl_ctx() { return hccl_ctx_; }

    // Profiling: set via LCA_PROFILE=1 env in main_cli. If enabled, decode() accumulates
    // per-phase wall-clock ms into the timer accumulators below.
    bool profile_enabled = false;
    double t_embed_ms = 0, t_layers_ms = 0, t_final_ms = 0;  // phase accumulators (ms)
    int64_t profile_calls = 0;                               // number of profiled decode() calls
    void print_profile_summary() const;

private:
    // One-layer forward: x_in [S, D] → x_out [S, D] via attention + residual + MoE + residual.
    // Uses this layer's KV cache starting at past_len; caller updates past_len after each call.
    // batch_decode_mode: true for S>1 at past_len>0 (spec decoding) — uses custom causal mask
    // with past instead of the 2048×2048 prefill mask.
    void layer_forward_(int layer_idx, int64_t S, void* x_in, void* x_out,
                        bool batch_decode_mode = false);

    // Build causal-with-past mask in batch_mask_dev_ for decode_batch at current past_len.
    // Shape [1, 1, S, past_len+S] bool, mask[i, j] = 1 iff j > past_len+i.
    void build_batch_decode_mask_(int64_t S);

    // Final: final_norm + lm_head on last position → logits [vocab].
    void final_logits_(void* hidden_last /*[1, D]*/, DeviceBuffer& logits_out);

    // Batched final: final_norm + lm_head on [S, D] → logits [S, V].
    void final_logits_batch_(void* hidden /*[S, D]*/, int64_t S, DeviceBuffer& logits_out);

    AclRuntime rt_;           // device/context/stream owner
    SafetensorsLoader st_;    // mmap-backed weight reader
    ModelConfig cfg_;         // hparams + TP-derived sizes
    HcclCtx hccl_ctx_;        // TP communicator (tp_size==1 ⇒ no-op collectives)
    int num_layers_ = 0;      // layers actually loaded (may be < 94 for testing)
    int64_t max_seq_ = 0;     // KV-cache capacity in tokens

    SharedWeights shared_;                    // embeddings, final norm, lm_head
    std::vector<LayerAttnWeights> attn_;      // per-layer attention weights
    std::vector<LayerMoEWeights> moe_;        // per-layer MoE weights

    // Per-layer KV cache
    std::vector<DeviceBuffer> k_cache_;
    std::vector<DeviceBuffer> v_cache_;

    // Scratch (reallocated per-call sized by current S)
    DeviceBuffer q_sc_, k_sc_, v_sc_, xn_sc_, rstd_sc_, rope_sc_, attn_fias_sc_, attn_out_sc_;
    DeviceBuffer moe_xn_, moe_rstd_, moe_logits_;
    DeviceBuffer moe_topk_w_, moe_topk_idx_, moe_row_idx_;
    DeviceBuffer moe_ex_x_, moe_ex_ri_, moe_tpe_;
    DeviceBuffer moe_fwd_;
    DeviceBuffer moe_gate_, moe_up_, moe_down_;
    DeviceBuffer moe_packed_, moe_weighted_, moe_out_;
    DeviceBuffer moe_norm_sum_;     // BF16 [S, 1] for on-device topk_w normalize
    DeviceBuffer x_buf_a_, x_buf_b_; // ping-pong for residual chain

    // Causal mask for prefill (2048 x 2048 bool); decode uses nullptr
    DeviceBuffer prefill_mask_dev_;

    // Batch decode mask: S_MAX × KV_MAX bool, where mask[i, j] = 1 (masked out) if
    // j > past_len + i. Built on-demand per-call (past_len changes).
    DeviceBuffer batch_mask_dev_;

    // Pre-computed RoPE cos/sin table (sized for max_seq_)
    RopeCache rope_cache_;

    int64_t past_len_ = 0;          // tokens currently resident in the KV cache
    int64_t cur_S_capacity_ = 0;    // scratch sized for this many tokens
};
|
|
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// safetensors_loader.h — lazy multi-shard safetensors reader.
|
| 2 |
+
//
|
| 3 |
+
// Usage:
|
| 4 |
+
// SafetensorsLoader loader;
|
| 5 |
+
// loader.open("/path/to/model_dir"); // parses index.json + all shard headers
|
| 6 |
+
// auto meta = loader.get("model.layers.0.self_attn.q_proj.weight");
|
| 7 |
+
// const void* host_ptr = loader.data_ptr(meta); // mmap-backed, host memory
|
| 8 |
+
// // copy to device: aclrtMemcpy(d_ptr, n, host_ptr, n, ACL_MEMCPY_HOST_TO_DEVICE);
|
| 9 |
+
//
|
| 10 |
+
// Files are mmap'd on first access and unmapped at destruction.
|
| 11 |
+
//
|
| 12 |
+
#pragma once
|
| 13 |
+
#include <cstdint>
|
| 14 |
+
#include <map>
|
| 15 |
+
#include <string>
|
| 16 |
+
#include <unordered_map>
|
| 17 |
+
#include <vector>
|
| 18 |
+
|
| 19 |
+
// Metadata for a single tensor as described in a shard's JSON header.
struct TensorMeta {
    std::string name;            // full HF tensor name (e.g. "model.layers.0.self_attn.q_proj.weight")
    std::string dtype;           // safetensors dtype string: "BF16", "F16", "F32", "I32", "I64"
    std::vector<int64_t> shape;  // logical shape, row-major
    int shard_id = -1;           // index into SafetensorsLoader::shards_
    size_t offset = 0;           // byte offset within shard (after 8B header_len + JSON header)
    size_t nbytes = 0;           // total payload size in bytes
};
|
| 27 |
+
|
| 28 |
+
// State for one .safetensors shard file; mmap'd lazily on first tensor access.
struct ShardFile {
    std::string path;          // filesystem path to the shard
    int fd = -1;               // open file descriptor (-1 = not opened yet)
    void* mmap_ptr = nullptr;  // base of the mapping (nullptr = not mapped yet)
    size_t mmap_size = 0;      // length of the mapping in bytes
    size_t data_base = 0;      // byte offset to first tensor data within file
};
|
| 35 |
+
|
| 36 |
+
// Lazy multi-shard safetensors reader. Parses the HF index.json plus every
// shard's JSON header at open(); tensor payloads are mmap'd on first access
// and unmapped at destruction.
class SafetensorsLoader {
public:
    SafetensorsLoader();
    ~SafetensorsLoader();

    // Opens a HuggingFace model directory. Returns false on failure.
    // Expects: <dir>/model.safetensors.index.json + model-XXXXX-of-YYYYY.safetensors
    bool open(const std::string& model_dir);

    // Get tensor metadata. Returns nullptr if name not found.
    const TensorMeta* get(const std::string& name) const;

    // Return host pointer to tensor's raw bytes (mmap-backed). Null if not found or mmap failed.
    // Pointer stays valid until this loader is destroyed.
    const void* data_ptr(const TensorMeta& m);
    const void* data_ptr(const std::string& name);

    // Enumerate all tensor names (stable order = lexicographic).
    std::vector<std::string> list_tensor_names() const;

    // ---- Stats ----
    size_t tensor_count() const { return tensors_.size(); }
    size_t shard_count() const { return shards_.size(); }
    size_t total_bytes() const;

private:
    // Parse one shard's 8-byte header length + JSON header into tensors_.
    bool parse_shard_header_(int shard_id);
    // mmap the shard on first data access; idempotent.
    bool mmap_shard_(int shard_id);

    std::string model_dir_;
    std::vector<ShardFile> shards_;
    std::map<std::string, TensorMeta> tensors_; // ordered for determinism
};
|
| 68 |
+
|
| 69 |
+
// ---- Helpers ----
|
| 70 |
+
|
| 71 |
+
// Convert a safetensors dtype string to its element byte size.
// Returns 0 for unknown dtypes so callers can detect unsupported tensors.
// Covers the dtypes this loader uses plus the remaining safetensors-spec
// unsigned widths (U16/U32/U64) so shards containing such tensors still
// get correct byte counts instead of a silent 0.
inline size_t sdtype_size(const std::string& s) {
  if (s == "F64" || s == "I64" || s == "U64") return 8;
  if (s == "F32" || s == "I32" || s == "U32") return 4;
  if (s == "F16" || s == "BF16" || s == "I16" || s == "U16") return 2;
  if (s == "I8" || s == "U8" || s == "BOOL") return 1;
  return 0;
}
|
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// tokenizer.h — minimal Qwen3 tokenizer.
|
| 2 |
+
//
|
| 3 |
+
// M2-phase1: decode() is native C++ (simple vocab lookup). encode() is a Python subprocess
|
| 4 |
+
// (one-time cost at prompt setup). Native BPE encode is a future item.
|
| 5 |
+
//
|
| 6 |
+
#pragma once
|
| 7 |
+
#include <string>
|
| 8 |
+
#include <vector>
|
| 9 |
+
#include <cstdint>
|
| 10 |
+
|
| 11 |
+
// Minimal Qwen3 tokenizer: native C++ decode (vocab lookup), Python-subprocess
// encode (proper BPE; one-time cost per prompt). Native BPE encode is a future item.
class Tokenizer {
public:
    // Load the id→bytes table from a pre-exported vocab.bin. False on failure.
    bool load(const std::string& vocab_bin_path);

    // Decode a single token id to UTF-8 string.
    std::string decode(int token_id) const;

    // Decode list of token ids to concatenated UTF-8 string.
    std::string decode(const std::vector<int>& token_ids) const;

    // Encode prompt to token ids. Uses a Python subprocess since Qwen3 needs proper BPE.
    // The subprocess call takes ~200ms but is only invoked once per prompt.
    std::vector<int> encode_via_python(const std::string& model_dir,
                                       const std::string& prompt,
                                       bool apply_chat_template = false) const;

    // Encode a multi-turn conversation by applying the model's chat template. Each pair is
    // (role, content) — typical roles: "system", "user", "assistant". Uses Python subprocess.
    std::vector<int> encode_conversation_via_python(
        const std::string& model_dir,
        const std::vector<std::pair<std::string, std::string>>& conversation,
        bool add_generation_prompt = true) const;

    // Vocabulary size (number of loaded token entries).
    size_t size() const { return id_to_bytes_.size(); }

private:
    std::vector<std::string> id_to_bytes_; // id -> raw utf-8 bytes
};
|
|
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// workspace_pool.h — reusable aclnn workspace buffer pool.
|
| 2 |
+
//
|
| 3 |
+
// Problem: every aclnn op does `aclrtMalloc(workspace)` + `aclrtFree`. For decode at 94 layers
|
| 4 |
+
// × ~30 ops = 2820 mallocs/frees per token, this is significant overhead.
|
| 5 |
+
//
|
| 6 |
+
// Solution: pool of DeviceBuffers, grow-only. Pool returns a pointer >= requested size.
|
| 7 |
+
// Most ops reuse the SAME buffer since they don't overlap on-stream (serial execution).
|
| 8 |
+
//
|
| 9 |
+
// Thread safety: not thread-safe. One pool per Runner (one thread).
|
| 10 |
+
#pragma once
|
| 11 |
+
#include "acl_common.h"
|
| 12 |
+
#include <algorithm>
|
| 13 |
+
#include <vector>
|
| 14 |
+
|
| 15 |
+
class WorkspacePool {
|
| 16 |
+
public:
|
| 17 |
+
WorkspacePool() = default;
|
| 18 |
+
~WorkspacePool() = default;
|
| 19 |
+
WorkspacePool(const WorkspacePool&) = delete;
|
| 20 |
+
WorkspacePool& operator=(const WorkspacePool&) = delete;
|
| 21 |
+
|
| 22 |
+
// Return a device pointer of at least `bytes`. Reuses the current buffer
|
| 23 |
+
// if it's big enough; otherwise grows by allocating a new one and
|
| 24 |
+
// **retaining old buffers** (async kernels may still be reading them —
|
| 25 |
+
// freeing too early would corrupt in-flight workspaces).
|
| 26 |
+
//
|
| 27 |
+
// Periodically call `reset_after_sync()` when the stream is idle to
|
| 28 |
+
// reclaim all-but-largest buffers and reset grow count.
|
| 29 |
+
void* alloc(size_t bytes) {
|
| 30 |
+
if (bytes == 0) return nullptr;
|
| 31 |
+
if (current_size_ < bytes) {
|
| 32 |
+
// Keep old buffer alive (don't free!) — aclnn kernels may still use it.
|
| 33 |
+
old_bufs_.push_back(std::move(buf_));
|
| 34 |
+
buf_.alloc(bytes);
|
| 35 |
+
current_size_ = bytes;
|
| 36 |
+
grow_count_++;
|
| 37 |
+
}
|
| 38 |
+
return buf_.get();
|
| 39 |
+
}
|
| 40 |
+
|
| 41 |
+
size_t current_size() const { return current_size_; }
|
| 42 |
+
size_t grow_count() const { return grow_count_; }
|
| 43 |
+
size_t retained_count() const { return old_bufs_.size(); }
|
| 44 |
+
|
| 45 |
+
// Call only when the stream is guaranteed idle (e.g., after aclrtSynchronizeStream).
|
| 46 |
+
// Drops all retained older buffers, freeing device memory. Current active buffer kept.
|
| 47 |
+
void reset_after_sync() {
|
| 48 |
+
old_bufs_.clear();
|
| 49 |
+
}
|
| 50 |
+
|
| 51 |
+
void clear() {
|
| 52 |
+
old_bufs_.clear();
|
| 53 |
+
buf_ = DeviceBuffer();
|
| 54 |
+
current_size_ = 0;
|
| 55 |
+
grow_count_ = 0;
|
| 56 |
+
}
|
| 57 |
+
|
| 58 |
+
private:
|
| 59 |
+
DeviceBuffer buf_; // current active (largest so far)
|
| 60 |
+
std::vector<DeviceBuffer> old_bufs_; // older, smaller — still live until stream sync
|
| 61 |
+
size_t current_size_ = 0;
|
| 62 |
+
size_t grow_count_ = 0;
|
| 63 |
+
};
|
| 64 |
+
|
| 65 |
+
// Convenience: per-stream RAII guard that acts like a `DeviceBuffer` but draws from pool.
|
| 66 |
+
// Used in aclnn_ops.h wrappers as a drop-in replacement for the local DeviceBuffer.
|
| 67 |
+
class PoolBuffer {
|
| 68 |
+
public:
|
| 69 |
+
// Fallback mode: if pool is nullptr, allocate own buffer (current behavior).
|
| 70 |
+
// Pool mode: return pool's shared pointer.
|
| 71 |
+
PoolBuffer(WorkspacePool* pool, size_t bytes) {
|
| 72 |
+
if (pool) {
|
| 73 |
+
ptr_ = pool->alloc(bytes);
|
| 74 |
+
} else if (bytes > 0) {
|
| 75 |
+
local_.alloc(bytes);
|
| 76 |
+
ptr_ = local_.get();
|
| 77 |
+
}
|
| 78 |
+
}
|
| 79 |
+
void* get() { return ptr_; }
|
| 80 |
+
|
| 81 |
+
private:
|
| 82 |
+
DeviceBuffer local_; // only used when pool is null
|
| 83 |
+
void* ptr_ = nullptr;
|
| 84 |
+
};
|
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
# bench_hccl.sh — HCCL parameter-matrix benchmark for TG (token generation).
#
# Sweeps HCCL_ALGO × HCCL_BUFFSIZE; each cell runs N_RUNS times and the best
# result is recorded. Fixed prompt + seed=0 + fixed n-predict keep runs comparable.
set -u
cd "$(dirname "$0")/.."

MODEL="${MODEL_DIR:-/path/to/Qwen3-235B-A22B-Instruct-2507-BF16}"
BIN="./build/qwen3-moe-aclnn"
LAUNCH="./scripts/tp_launch.sh"
TP="${TP_SIZE:-16}"
N_PREDICT="${N_PREDICT:-150}"
N_RUNS="${N_RUNS:-2}"
PROMPT="${PROMPT:-The history of artificial intelligence spans several decades and}"
VOCAB="tokenizer_data/vocab.bin"

OUT=/tmp/bench_hccl_results.csv
echo "algo,buffsize,runs,best_tgs" > "$OUT"

# Run one (algo, buffsize) cell: N_RUNS launches, record every TG plus the best.
run_one() {
    local algo="$1" buf="$2"
    local tgs=()
    for r in $(seq 1 "$N_RUNS"); do
        export HCCL_ALGO="$algo" HCCL_BUFFSIZE="$buf"
        local out
        out=$(${LAUNCH} ${TP} ${BIN} --model-dir "$MODEL" \
                --prompt "$PROMPT" --n-predict "$N_PREDICT" \
                --vocab "$VOCAB" --seed 0 2>&1 | grep "decode :" | awk '{print $(NF-2)}')
        tgs+=("${out:-0}")
    done
    local sorted
    sorted=($(printf '%s\n' "${tgs[@]}" | sort -n))
    local best="${sorted[-1]}"
    # Join per-run TGs with '/' so each row stays exactly 4 CSV columns
    # (the old whole-line `sed 's/ /|/g'` mangled the runs field).
    echo "$algo,$buf,$(IFS=/; echo "${tgs[*]}"),$best" >> "$OUT"
    printf " %-22s buf=%-4s %s best=%s\n" \
        "${algo:-(auto)}" "$buf" "${tgs[*]}" "$best"
}

# Matrix
ALGOS=("" "level0:ring" "level0:fullmesh")
BUFSIZES=("100" "200" "400")

echo "HCCL matrix: ${#ALGOS[@]} algos × ${#BUFSIZES[@]} buffsizes × ${N_RUNS} runs each"
echo "Results → $OUT"
echo ""

for algo in "${ALGOS[@]}"; do
    for buf in "${BUFSIZES[@]}"; do
        run_one "$algo" "$buf"
    done
done

echo ""
echo "====== Summary (sorted by best TG) ======"
(head -1 "$OUT"; tail -n +2 "$OUT" | sort -t, -k4 -gr) | column -t -s,
|
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
# bench_hccl_adv.sh — advanced HCCL parameter tuning.
# Adds knobs such as OP_EXPANSION_MODE=AIV on top of the established
# ring:200 baseline, one configuration at a time.
set -u
cd "$(dirname "$0")/.."

MODEL="${MODEL_DIR:-/path/to/Qwen3-235B-A22B-Instruct-2507-BF16}"
BIN="./build/qwen3-moe-aclnn"
LAUNCH="./scripts/tp_launch.sh"
TP=16
N_PREDICT=200
N_RUNS=2
PROMPT="The history of artificial intelligence spans several decades and"
VOCAB="tokenizer_data/vocab.bin"

OUT=/tmp/bench_hccl_adv.csv
echo "config,run1,run2,best,median" > "$OUT"

# run_one NAME [KEY=VALUE ...]
# Benchmarks one env-knob combination. KEY=VALUE pairs are handed to `env`
# after the fixed ring:200 baseline, so they can override it.
run_one() {
    local name="$1"; shift
    local tgs=()
    for r in $(seq 1 "$N_RUNS"); do
        local out
        # "$@" must stay quoted so each KEY=VALUE reaches `env` as one word.
        # (The previous env_cmd string-building here was dead code.)
        out=$(env HCCL_ALGO=level0:ring HCCL_BUFFSIZE=200 "$@" \
                ${LAUNCH} ${TP} ${BIN} --model-dir "$MODEL" \
                --prompt "$PROMPT" --n-predict "$N_PREDICT" \
                --vocab "$VOCAB" --seed 0 2>&1 | grep "decode :" | awk '{print $(NF-2)}')
        tgs+=("${out:-0}")
    done
    local sorted
    sorted=($(printf '%s\n' "${tgs[@]}" | sort -n))
    local best="${sorted[-1]}"
    local median="${sorted[$((${#sorted[@]}/2))]}"
    echo "$name,${tgs[0]},${tgs[1]},$best,$median" >> "$OUT"
    printf " %-40s %s best=%s median=%s\n" "$name" "${tgs[*]}" "$best" "$median"
}

echo "Adv HCCL bench: baseline ring:200 + additional knobs"
echo "Results → $OUT"
echo ""

run_one "baseline (ring+200 only)"

run_one "+ OP_EXPANSION_MODE=AIV" HCCL_OP_EXPANSION_MODE=AIV
run_one "+ OP_BASE_FFTS_MODE=1" HCCL_OP_BASE_FFTS_MODE_ENABLE=1
run_one "+ OP_EXPANSION=AIV + FFTS=1" HCCL_OP_EXPANSION_MODE=AIV HCCL_OP_BASE_FFTS_MODE_ENABLE=1
run_one "+ OP_EXPANSION=AIV + BUF=256" HCCL_OP_EXPANSION_MODE=AIV HCCL_BUFFSIZE=256
run_one "+ OP_EXPANSION=AIV + BUF=512" HCCL_OP_EXPANSION_MODE=AIV HCCL_BUFFSIZE=512
run_one "+ OP_EXPANSION=AIV + ALGO=fullmesh" HCCL_OP_EXPANSION_MODE=AIV HCCL_ALGO=level0:fullmesh

echo ""
echo "====== Sorted by best TG ======"
(head -1 "$OUT"; tail -n +2 "$OUT" | sort -t, -k4 -gr) | column -t -s,
|
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
# bench_hccl_adv2.sh — layer 2 env knob exploration on top of AIV+FFTS=1 baseline.
# Target: break past 25 t/s MUST barrier.
set -u
cd "$(dirname "$0")/.."

MODEL="${MODEL_DIR:-/path/to/Qwen3-235B-A22B-Instruct-2507-BF16}"
BIN="./build/qwen3-moe-aclnn"
LAUNCH="./scripts/tp_launch.sh"
TP=16
N_PREDICT=200
N_RUNS=3
LONG_PROMPT="Write a very long detailed essay about artificial intelligence, machine learning, deep learning and their applications in modern society. Include historical context, current state of the art, and future predictions."
VOCAB="tokenizer_data/vocab.bin"

OUT=/tmp/bench_hccl_adv2.csv
echo "config,runs,best,median" > "$OUT"

# run_one NAME [KEY=VALUE ...]
# Benchmarks one knob combination on top of the fixed AIV+FFTS baseline.
run_one() {
    local name="$1"; shift
    local tgs=()
    for r in $(seq 1 "$N_RUNS"); do
        local out
        out=$(env HCCL_ALGO=level0:ring HCCL_BUFFSIZE=200 \
                HCCL_OP_EXPANSION_MODE=AIV HCCL_OP_BASE_FFTS_MODE_ENABLE=1 \
                "$@" \
                ${LAUNCH} ${TP} ${BIN} --model-dir "$MODEL" \
                --prompt "$LONG_PROMPT" --n-predict "$N_PREDICT" \
                --vocab "$VOCAB" --seed 0 --no-stream 2>&1 \
                | grep "decode :" | awk '{print $(NF-2)}')
        tgs+=("${out:-0}")
    done
    local sorted
    sorted=($(printf '%s\n' "${tgs[@]}" | sort -n))
    local best="${sorted[-1]}"
    local median="${sorted[$((${#sorted[@]}/2))]}"
    # Join run values with '/' so each row stays exactly 4 CSV columns.
    # The old tr/sed pipeline turned the FIRST space into the field separator,
    # corrupting config names that contain spaces.
    echo "$name,$(IFS=/; echo "${tgs[*]}"),$best,$median" >> "$OUT"
    printf " %-40s %s best=%s median=%s\n" "$name" "${tgs[*]}" "$best" "$median"
}

echo "Bench: AIV+FFTS baseline + single additional knob"
echo "$N_RUNS runs × $N_PREDICT tokens"
echo ""

run_one "baseline (AIV + FFTS)"

run_one "+ TASK_QUEUE_ENABLE=1" TASK_QUEUE_ENABLE=1
run_one "+ TASK_QUEUE_ENABLE=2" TASK_QUEUE_ENABLE=2
run_one "+ HCCL_BUFFSIZE=256" HCCL_BUFFSIZE=256
run_one "+ HCCL_DETERMINISTIC=false" HCCL_DETERMINISTIC=false
run_one "+ HCCL_INTRA_ROCE_ENABLE=1" HCCL_INTRA_ROCE_ENABLE=1
run_one "+ HCCL_CLUSTER_TIMEOUT=600" HCCL_CLUSTER_TIMEOUT=600
run_one "+ ASCEND_LAUNCH_BLOCKING=0" ASCEND_LAUNCH_BLOCKING=0

echo ""
echo "====== Sorted by best TG ======"
(head -1 "$OUT"; tail -n +2 "$OUT" | sort -t, -k3 -gr) | column -t -s,
|
|
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
# bench_pld.sh — sweep K × n-gram with corrected causal-with-past mask.
# Measures TG + accept rate stability across N_RUNS per config.
set -u
cd "$(dirname "$0")/.."

MODEL="${MODEL_DIR:-/path/to/Qwen3-235B-A22B-Instruct-2507-BF16}"
BIN="./build/qwen3-moe-aclnn"
LAUNCH="./scripts/tp_launch.sh"
TP=16
N_PREDICT=200
N_RUNS="${N_RUNS:-3}"
PROMPT="${PROMPT:-Write a long Python function that computes the Fibonacci sequence with memoization, extensive comments, and type hints.}"
VOCAB="tokenizer_data/vocab.bin"

OUT=/tmp/bench_pld.csv
echo "k,ngram,run_tgs,best,median,avg_accept" > $OUT

# Benchmark one (K, n-gram) configuration: N_RUNS launches, collecting the
# decode TG and the PLD accept rate the binary reports on each run.
run_one() {
    local draft_k="$1" ngram="$2"
    local tgs=() accs=()
    for r in $(seq 1 "$N_RUNS"); do
        local full_log
        full_log=$(${LAUNCH} ${TP} ${BIN} --model-dir "$MODEL" \
            --prompt "$PROMPT" --n-predict "$N_PREDICT" --max-seq 512 \
            --vocab "$VOCAB" --seed 0 --no-stream \
            --pld --pld-k "$draft_k" --pld-ngram "$ngram" 2>&1)
        local tg acc
        tg=$(echo "$full_log" | grep "decode :" | awk '{print $(NF-2)}')
        acc=$(echo "$full_log" | grep "\[pld\]" | grep -oE "avg=[0-9.]+" | cut -d= -f2)
        tgs+=("${tg:-0}")
        accs+=("${acc:-0}")
    done
    local sorted
    sorted=($(printf '%s\n' "${tgs[@]}" | sort -n))
    local n=${#sorted[@]}
    local best="${sorted[-1]}"
    local median="${sorted[$((n/2))]}"
    local accs_avg
    accs_avg=$(printf '%s\n' "${accs[@]}" | awk '{s+=$1} END {printf "%.2f", s/NR}')
    echo "$draft_k,$ngram,$(IFS=/; echo "${tgs[*]}"),$best,$median,$accs_avg" >> $OUT
    printf " K=%-2d ng=%-1d runs=[%s] best=%s median=%s accept_avg=%s\n" \
        "$draft_k" "$ngram" "${tgs[*]}" "$best" "$median" "$accs_avg"
}

echo "PLD sweep on '$PROMPT' ($N_RUNS runs × $N_PREDICT tokens)"
echo ""

for k in 2 4 6 8 12; do
    for ng in 1 2 3; do
        run_one "$k" "$ng"
    done
done

# Baseline for reference (PLD disabled).
echo ""
echo "Baseline (no PLD):"
tgs=()
for r in $(seq 1 "$N_RUNS"); do
    tg=$(${LAUNCH} ${TP} ${BIN} --model-dir "$MODEL" \
        --prompt "$PROMPT" --n-predict "$N_PREDICT" --max-seq 512 \
        --vocab "$VOCAB" --seed 0 --no-stream 2>&1 | grep "decode :" | awk '{print $(NF-2)}')
    tgs+=("${tg:-0}")
done
sorted=($(printf '%s\n' "${tgs[@]}" | sort -n))
echo " baseline: ${tgs[*]} median=${sorted[$((${#sorted[@]}/2))]}"

echo ""
echo "====== Sorted by median TG ======"
(head -1 $OUT; tail -n +2 $OUT | sort -t, -k5 -gr) | column -t -s,
|
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
# bench_pld_k.sh — isolated K sweep with FIXED K (no adaptive) to characterize raw K effect.
# Larger K = more draft candidates per verify. Peak observed accept=7.38 suggests K=8 not saturated.
set -u
cd "$(dirname "$0")/.."

MODEL="${MODEL_DIR:-/path/to/Qwen3-235B-A22B-Instruct-2507-BF16}"
BIN="./build/qwen3-moe-aclnn"
LAUNCH="./scripts/tp_launch.sh"
TP=16
N_PREDICT=200
N_RUNS=3
PROMPT="Write a long Python function that computes the Fibonacci sequence with memoization, extensive comments, and type hints."
VOCAB="tokenizer_data/vocab.bin"

OUT=/tmp/bench_pld_k.csv
echo "k,runs,median,max,avg_accept" > $OUT

# For each fixed K: N_RUNS launches, then median/max TG and mean accept rate.
for K in 4 6 8 10 12 16; do
    tgs=() accs=()
    for r in $(seq 1 "$N_RUNS"); do
        run_log=$(${LAUNCH} ${TP} ${BIN} --model-dir "$MODEL" \
            --prompt "$PROMPT" --n-predict "$N_PREDICT" --max-seq 512 \
            --vocab "$VOCAB" --seed 0 --no-stream \
            --pld --pld-k "$K" --pld-ngram 1 --pld-fixed-k 2>&1)
        tg=$(echo "$run_log" | grep "decode :" | awk '{print $(NF-2)}')
        acc=$(echo "$run_log" | grep "\[pld\]" | grep -oE "avg=[0-9.]+" | cut -d= -f2)
        tgs+=("${tg:-0}")
        accs+=("${acc:-0}")
    done
    sorted=($(printf '%s\n' "${tgs[@]}" | sort -n))
    median="${sorted[$((${#sorted[@]}/2))]}"
    max="${sorted[-1]}"
    accs_avg=$(printf '%s\n' "${accs[@]}" | awk '{s+=$1} END {printf "%.2f", s/NR}')
    echo "$K,$(IFS=/; echo "${tgs[*]}"),$median,$max,$accs_avg" >> $OUT
    printf " K=%-2d runs=[%s] median=%s max=%s accept=%s\n" "$K" "${tgs[*]}" "$median" "$max" "$accs_avg"
done

echo ""
echo "====== Sorted by median ======"
(head -1 $OUT; tail -n +2 $OUT | sort -t, -k3 -gr) | column -t -s,
|
|
@@ -0,0 +1,154 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
# bench_pld_safe.sh — PLD benchmark with output correctness check.
# Unlike bench_tg.sh (which only reports TG numbers), this wrapper also inspects the
# generated text for degeneration signals (consecutive identical tokens / very low
# distinct-token ratio in the tail) and flags runs whose high TG came from dead-loop
# output rather than real acceleration.
#
# Usage: ./scripts/bench_pld_safe.sh [N_RUNS] [PROMPT_FILE]
#   PROMPT_FILE lines use a "|" separator: "tag|prompt text"
#   Default: a small suite covering several prompt classes, reporting which
#   ones PLD helps safely.
set -u
cd "$(dirname "$0")/.."

MODEL="${MODEL_DIR:-/path/to/Qwen3-235B-A22B-Instruct-2507-BF16}"
BIN="./build/qwen3-moe-aclnn"
N_RUNS="${1:-3}"
N_PREDICT="${N_PREDICT:-120}"
VOCAB="tokenizer_data/vocab.bin"

# One default prompt per class; override via arg 2 (file of "tag|prompt" lines).
default_prompts=(
    "story|Once upon a time, in a small village,"
    "factual|The capital of France is"
    "code|Write a Python function that computes Fibonacci."
    "essay|The history of artificial intelligence spans several decades and"
)

if [ "${2:-}" != "" ] && [ -f "${2:-}" ]; then
    mapfile -t prompts < "$2"
else
    prompts=("${default_prompts[@]}")
fi
|
| 33 |
+
|
| 34 |
+
# ----- Correctness classifier -----
# Reads generated text from stdin and prints exactly one verdict:
#   OK               — no loop signals
#   LOOP_N           — N+ consecutive identical non-space words detected
#   LOW_DIVERSITY_…  — tail 40 words have < 10 distinct words (heavy repetition)
classify_output() {
    awk '
  {
    # Tokenize on whitespace; strip punct at edges for comparison.
    n = split($0, w, /[[:space:]]+/);
    for (i = 1; i <= n; i++) {
      gsub(/^[[:punct:]]+|[[:punct:]]+$/, "", w[i]);
      if (w[i] == "") continue;
      words[++nw] = tolower(w[i]);
    }
  }
  END {
    if (nw < 5) { print "OK"; exit }
    # consecutive-same detection
    run = 1; max_run = 1;
    for (i = 2; i <= nw; i++) {
      if (words[i] == words[i-1]) { run++; if (run > max_run) max_run = run; }
      else run = 1;
    }
    if (max_run >= 6) { printf "LOOP_%d\n", max_run; exit }

    # tail diversity: last 40 words
    tail_start = nw - 39; if (tail_start < 1) tail_start = 1;
    delete seen;
    distinct = 0;
    for (i = tail_start; i <= nw; i++) {
      if (!(words[i] in seen)) { seen[words[i]] = 1; distinct++; }
    }
    tail_n = nw - tail_start + 1;
    if (tail_n >= 20 && distinct < 10) {
      printf "LOW_DIVERSITY_%d/%d\n", distinct, tail_n;
      exit;
    }
    print "OK";
  }'
}
|
| 75 |
+
|
| 76 |
+
# Launch one generation, classify its output, and print a single line:
#   "TG|VERDICT|WARN|PLD_LINE"
# $1 = prompt, $2 = extra CLI flags (word-split intentionally).
run_once() {
    local prompt="$1"
    local extra_flags="$2"
    # The binary writes rank/cli headers, runner loading lines, the generated
    # text (--no-stream) and perf lines to stdout; [pld]/[warn] go to stderr.
    local stdout_file stderr_file
    stdout_file=$(mktemp)
    stderr_file=$(mktemp)
    # Ensure no lockfile leftover.
    ssh_cleanup_lockfile
    ./scripts/tp_launch.sh 16 $BIN --model-dir "$MODEL" \
        --prompt "$prompt" --n-predict $N_PREDICT \
        --vocab "$VOCAB" --seed 0 --no-stream --temperature 0 \
        $extra_flags 1>"$stdout_file" 2>"$stderr_file"
    # TG lives on stdout (from printf in binary).
    local tg
    tg=$(grep "\[perf\] decode" "$stdout_file" | awk '{print $(NF-2)}')
    # Generated text: the line that begins with the prompt (--no-stream echoes prompt+text).
    local gen_text
    gen_text=$(grep -F -- "$prompt" "$stdout_file" | grep -v '^\[' | tail -1)
    # Quote the pattern: otherwise glob metacharacters (*, ?, [) inside the
    # prompt are interpreted and the prefix strip silently misbehaves.
    local stripped="${gen_text#"$prompt"}"
    local verdict
    verdict=$(echo "$stripped" | classify_output)
    local has_warn=""
    if grep -q "\[warn\]" "$stderr_file"; then has_warn="WARN"; fi
    local pld_line
    pld_line=$(grep "\[pld\]" "$stderr_file" | tail -1 | sed 's/^\[pld\] //')
    rm -f "$stdout_file" "$stderr_file"
    echo "${tg:-0}|${verdict}|${has_warn}|${pld_line}"
}
|
| 101 |
+
|
| 102 |
+
# Remove any stale HCCL rendezvous file left behind by a previous run.
ssh_cleanup_lockfile() {
    rm -f /tmp/hccl_root_info.bin 2>/dev/null || true
}
|
| 105 |
+
|
| 106 |
+
# Run one (tag, prompt, flags) cell N_RUNS times, classify every run's output,
# and report mean TG separately for clean (OK) vs degraded runs.
bench_prompt() {
    local tag="$1"; local prompt="$2"; local flags="$3"
    echo ""
    echo "=== [$tag] $(echo "$prompt" | head -c 50)... (flags: ${flags:-none}) ==="
    local tgs=() verdicts=() warns=() plds=()
    for r in $(seq 1 $N_RUNS); do
        result=$(run_once "$prompt" "$flags")
        IFS='|' read -r tg verdict warn pld <<< "$result"
        printf " run %d: TG=%s verdict=%s %s\n" "$r" "$tg" "$verdict" "$warn"
        [ -n "$pld" ] && printf " %s\n" "$pld"
        tgs+=("${tg:-0}"); verdicts+=("$verdict"); warns+=("$warn")
        rm -f /tmp/hccl_root_info.bin
    done
    # Partition run TGs by verdict.
    local good_tgs=() bad_tgs=()
    for i in "${!tgs[@]}"; do
        if [ "${verdicts[$i]}" = "OK" ]; then good_tgs+=("${tgs[$i]}"); else bad_tgs+=("${tgs[$i]}"); fi
    done
    local n_good=${#good_tgs[@]}
    local n_bad=${#bad_tgs[@]}
    echo " → $n_good/$N_RUNS OK, $n_bad/$N_RUNS degraded"
    if [ $n_good -gt 0 ]; then
        local mean
        mean=$(printf '%s\n' "${good_tgs[@]}" | awk '{s+=$1} END {printf "%.2f", s/NR}')
        echo " → OK mean TG: $mean t/s (values: ${good_tgs[*]})"
    fi
    if [ $n_bad -gt 0 ]; then
        local bad_mean
        bad_mean=$(printf '%s\n' "${bad_tgs[@]}" | awk '{s+=$1} END {printf "%.2f", s/NR}')
        echo " → degraded mean TG: $bad_mean t/s (DO NOT REPORT as speedup) (values: ${bad_tgs[*]})"
    fi
}
|
| 136 |
+
|
| 137 |
+
echo "bench_pld_safe: $N_RUNS runs × $N_PREDICT tokens per prompt; comparing [no-pld, pld+guard, pld+no-guard]"

# Each suite entry is "tag|prompt"; run all three modes per prompt.
for entry in "${prompts[@]}"; do
    tag="${entry%%|*}"
    prompt="${entry#*|}"
    bench_prompt "$tag/base" "$prompt" ""
    bench_prompt "$tag/pld+guard" "$prompt" "--pld"
    bench_prompt "$tag/pld-raw" "$prompt" "--pld --pld-no-guard"
done

echo ""
echo "=========================================================="
echo "Interpretation:"
echo " OK mean TG is the only honest number to report."
echo " Any 'degraded' result with high TG is a dead-loop artifact."
echo " Expected: pld+guard matches or beats base on creative/story prompts,"
echo " matches base on factual/code prompts (drafts rejected → fallback to single decode)."
echo " pld-raw (no guard) on repetitive prompts produces 'degraded' with high TG."
|
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
# bench_tg.sh — stable TG measurement: N runs × N_PREDICT tokens, report
# min/median/mean/max. The headline number is the MEDIAN over all runs, which
# is robust to cold-start outliers (no run is actually dropped — the old
# header's "drop cold-starts" claim was inaccurate).
#
# Usage: ./scripts/bench_tg.sh [N_RUNS]          (default 5)
#        LCA_WARMUP=3 ./scripts/bench_tg.sh      (with warmup enabled)
set -u
cd "$(dirname "$0")/.."

MODEL="${MODEL_DIR:-/path/to/Qwen3-235B-A22B-Instruct-2507-BF16}"
BIN="./build/qwen3-moe-aclnn"
N_RUNS="${1:-5}"
N_PREDICT="${N_PREDICT:-200}"
PROMPT="The history of artificial intelligence spans several decades and"
VOCAB="tokenizer_data/vocab.bin"

echo "bench_tg: $N_RUNS runs × $N_PREDICT tokens (LCA_WARMUP=${LCA_WARMUP:-0})"
tgs=()
for r in $(seq 1 "$N_RUNS"); do
    # run_tg replaces the old misleading name `local_out` (this is top-level
    # code, not a function — `local` does not apply here).
    run_tg=$(./scripts/tp_launch.sh 16 $BIN --model-dir "$MODEL" \
        --prompt "$PROMPT" --n-predict "$N_PREDICT" \
        --vocab "$VOCAB" --seed 0 2>&1 | grep "decode :" | awk '{print $(NF-2)}')
    printf " run %d: %s t/s\n" "$r" "$run_tg"
    tgs+=("${run_tg:-0}")
done

echo ""
echo "====== Summary ======"
sorted=($(printf '%s\n' "${tgs[@]}" | sort -n))
n=${#sorted[@]}
mid=$((n / 2))
median="${sorted[$mid]}"
min="${sorted[0]}"
max="${sorted[-1]}"
mean=$(printf '%s\n' "${tgs[@]}" | awk '{s+=$1} END {printf "%.2f", s/NR}')

echo " all : ${tgs[*]}"
echo " min : $min t/s"
echo " median : $median t/s"
echo " mean : $mean t/s"
echo " max : $max t/s"
|
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""Export Qwen3 tokenizer vocab to a simple binary format.
|
| 3 |
+
|
| 4 |
+
Format (little-endian):
|
| 5 |
+
u32 num_tokens
|
| 6 |
+
for each id in [0, num_tokens):
|
| 7 |
+
u32 byte_length
|
| 8 |
+
u8[byte_length] utf8_bytes
|
| 9 |
+
|
| 10 |
+
Also emits special_tokens.txt with id + content pairs for reference.
|
| 11 |
+
"""
|
| 12 |
+
import json, sys, struct, os
|
| 13 |
+
|
| 14 |
+
# CLI: [model_dir] [out_dir] — both optional, with project defaults.
model_dir = sys.argv[1] if len(sys.argv) > 1 else '/path/to/Qwen3-235B-A22B-Instruct-2507-BF16'
out_dir = sys.argv[2] if len(sys.argv) > 2 else 'tokenizer_data'
os.makedirs(out_dir, exist_ok=True)

# Load the full HF tokenizer spec (vocab + merges + added tokens).
with open(os.path.join(model_dir, 'tokenizer.json'), 'r') as f:
    tok = json.load(f)
|
| 20 |
+
|
| 21 |
+
# Byte-level decoder map: HF Qwen uses byte-level BPE like GPT-2.
# Vocab tokens are stored using printable stand-in characters (U+0100.. etc.)
# for raw bytes; for decoding we only need the reverse map char -> raw byte.
def build_byte_decoder():
    """Reverse of GPT-2's bytes_to_unicode: map each printable stand-in
    character back to the single raw byte it encodes."""
    # Bytes kept as their own character: printable ASCII plus two Latin-1 runs.
    keep = (
        list(range(ord('!'), ord('~') + 1))
        + list(range(ord('¡'), ord('¬') + 1))
        + list(range(ord('®'), ord('ÿ') + 1))
    )
    codepoints = list(keep)
    # Every remaining byte is shifted into the U+0100+ range in order.
    shift = 0
    for byte in range(2 ** 8):
        if byte not in keep:
            keep.append(byte)
            codepoints.append(2 ** 8 + shift)
            shift += 1
    return {chr(cp): bytes([b]) for b, cp in zip(keep, codepoints)}

byte_decoder = build_byte_decoder()
|
| 36 |
+
|
| 37 |
+
# Merge vocab + added_tokens into an id -> raw utf8_bytes lookup.
vocab = tok['model']['vocab']        # {token_str: id}
added = tok.get('added_tokens', [])  # list of {id, content, ...}

id_to_bytes = {}
for token, tid in vocab.items():
    # Undo the byte-level encoding: each mapped char becomes one raw byte;
    # anything outside the map passes through as plain UTF-8.
    pieces = []
    for ch in token:
        if ch in byte_decoder:
            pieces.append(byte_decoder[ch])
        else:
            pieces.append(ch.encode('utf-8'))
    id_to_bytes[int(tid)] = b''.join(pieces)

for a in added:
    # Special tokens are stored as raw utf8 content, not byte-level encoded.
    id_to_bytes[int(a['id'])] = a['content'].encode('utf-8')
|
| 55 |
+
|
| 56 |
+
max_id = max(id_to_bytes.keys())
num = max_id + 1
print(f"max_id = {max_id}, num_tokens = {num}")
print(f"num_special_tokens = {len(added)}")

# Write vocab.bin: u32 token count, then (u32 byte_length, raw bytes) per id.
# Ids absent from the map are written as zero-length entries.
vocab_path = os.path.join(out_dir, 'vocab.bin')
with open(vocab_path, 'wb') as f:
    f.write(struct.pack('<I', num))
    for i in range(num):
        entry = id_to_bytes.get(i, b'')
        f.write(struct.pack('<I', len(entry)))
        f.write(entry)
print(f"Wrote {vocab_path} ({os.path.getsize(vocab_path)} bytes)")

# Write special tokens (id<TAB>content per line) for reference.
with open(os.path.join(out_dir, 'special_tokens.txt'), 'w') as f:
    for a in added:
        f.write(f"{a['id']}\t{a['content']}\n")
print(f"Wrote special_tokens.txt")
|
| 76 |
+
|
| 77 |
+
# Verify against the HF tokenizer on a known prompt.
from transformers import AutoTokenizer
atok = AutoTokenizer.from_pretrained(model_dir)
test = "The capital of France is"
ids = atok.encode(test)
print(f"\nTest encode '{test}' -> {ids}")
# Concatenate the raw bytes of ALL tokens BEFORE decoding: a multi-byte UTF-8
# character can be split across two BPE tokens, so decoding each token
# individually would corrupt it (each fragment becomes U+FFFD).
decoded = b''.join(id_to_bytes.get(i, b'?') for i in ids).decode('utf-8', errors='replace')
print(f"Our decode: '{decoded}'")
print(f"HF decode: '{atok.decode(ids)}'")
|
|
@@ -0,0 +1,179 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
"""Generate a single-layer attention forward reference for Qwen3-235B layer 0.

Input: token ids (representing "The capital of France is")
Output: hidden_states after layer 0 attention (residual already added).
Also dumps all intermediate tensors for step-wise debugging.

All binary dumps are raw BF16 bytes (written via an int16 reinterpret), so the
C++ side can mmap/fread them directly without any conversion.
Requires an Ascend NPU with torch_npu installed; MODEL_DIR must point at the
BF16 safetensors checkout.
"""
import os, json, math, struct
import torch
import torch_npu
from safetensors.torch import load_file

torch.npu.set_device(0)
torch.set_grad_enabled(False)

MODEL_DIR = '/path/to/Qwen3-235B-A22B-Instruct-2507-BF16'
OUT_DIR = 'tests/attn_data'
os.makedirs(OUT_DIR, exist_ok=True)

# Model hyperparameters, read from the HF config so they track the checkpoint.
cfg = json.load(open(os.path.join(MODEL_DIR, 'config.json')))
D = cfg['hidden_size']            # 4096
Hq = cfg['num_attention_heads']   # 64
Hkv = cfg['num_key_value_heads']  # 4
Dh = cfg['head_dim']              # 128
Q_DIM = Hq * Dh                   # 8192
KV_DIM = Hkv * Dh                 # 512
eps = cfg['rms_norm_eps']
theta = cfg['rope_theta']         # 5e6 for Qwen3-235B

# ---- Find which safetensors shard contains layer 0 attention + input_layernorm ----
idx = json.load(open(os.path.join(MODEL_DIR, 'model.safetensors.index.json')))
wm = idx['weight_map']

needed = [
    'model.embed_tokens.weight',
    'model.layers.0.input_layernorm.weight',
    'model.layers.0.self_attn.q_proj.weight',
    'model.layers.0.self_attn.k_proj.weight',
    'model.layers.0.self_attn.v_proj.weight',
    'model.layers.0.self_attn.o_proj.weight',
    'model.layers.0.self_attn.q_norm.weight',
    'model.layers.0.self_attn.k_norm.weight',
]
shards = sorted({wm[n] for n in needed})
print("Need to load shards:", shards)

# Only the tensors in `needed` are moved to device; the rest of each shard is
# loaded to host by load_file and then dropped.
weights = {}
for sh in shards:
    t = load_file(os.path.join(MODEL_DIR, sh))
    for n in needed:
        if n in t:
            weights[n] = t[n].to('npu')
print("loaded:", list(weights.keys()))

# ---- Forward ----
# Input tokens (from tokenizer: "The capital of France is")
token_ids = torch.tensor([785, 6722, 315, 9625, 374], dtype=torch.long).npu()
S = token_ids.shape[0]
print(f"S = {S}")

# Embedding lookup
x = weights['model.embed_tokens.weight'][token_ids]  # [S, D]
x = x.unsqueeze(0)  # [1, S, D]
print("embed x:", x.shape, x.dtype)

# Residual (saved before the norm; re-added after o_proj)
residual = x

# Input layernorm (RMSNorm)
ln = weights['model.layers.0.input_layernorm.weight']
xn, _ = torch_npu.npu_rms_norm(x, ln, epsilon=eps)
print("after_input_norm xn:", xn.shape)

# Q/K/V projections — HF weights are [out, in], so y = x @ W.T
Wq = weights['model.layers.0.self_attn.q_proj.weight']
Wk = weights['model.layers.0.self_attn.k_proj.weight']
Wv = weights['model.layers.0.self_attn.v_proj.weight']
q = torch.matmul(xn, Wq.t())  # [1, S, Q_DIM]
k = torch.matmul(xn, Wk.t())  # [1, S, KV_DIM]
v = torch.matmul(xn, Wv.t())

# Reshape to heads (BSND)
q = q.view(1, S, Hq, Dh)
k = k.view(1, S, Hkv, Dh)
v = v.view(1, S, Hkv, Dh)

# Per-head RMSNorm on head_dim (Qwen3 specific: q_norm/k_norm are [Dh] and
# applied to every head independently; v is NOT normed)
qn_w = weights['model.layers.0.self_attn.q_norm.weight']  # [Dh]
kn_w = weights['model.layers.0.self_attn.k_norm.weight']
q_normed, _ = torch_npu.npu_rms_norm(q, qn_w, epsilon=eps)
k_normed, _ = torch_npu.npu_rms_norm(k, kn_w, epsilon=eps)

# RoPE: compute cos/sin for positions [0, S)
position_ids = torch.arange(S, device='npu').unsqueeze(0)  # [1, S]
inv_freq = 1.0 / (theta ** (torch.arange(0, Dh, 2, dtype=torch.float32).npu() / Dh))
freqs = position_ids.float().unsqueeze(-1) * inv_freq.unsqueeze(0).unsqueeze(0)  # [1, S, Dh/2]
# Concat (half, half) to get [1, S, Dh] — HF "rotate_half" layout, not interleaved
emb = torch.cat([freqs, freqs], dim=-1)
cos = emb.cos().to(torch.bfloat16)  # [1, S, Dh]
sin = emb.sin().to(torch.bfloat16)

# Apply RoPE — npu_apply_rotary_pos_emb expects BSND layout
# cos/sin shape: [1, S, 1, Dh] for broadcast over heads
cos_b = cos.unsqueeze(2)
sin_b = sin.unsqueeze(2)
q_roped, k_roped = torch_npu.npu_apply_rotary_pos_emb(q_normed, k_normed, cos_b, sin_b)

# Flatten for FIAS (BSH layout)
q_bsh = q_roped.reshape(1, S, Q_DIM)
k_bsh = k_roped.reshape(1, S, KV_DIM)
v_bsh = v.reshape(1, S, KV_DIM)

# FIAS with causal mask for prefill
scale = 1.0 / math.sqrt(Dh)
# sparse_mode=3 requires fixed 2048×2048 mask
# (upper-triangular True = masked-out positions)
MASK_SIZE = 2048
mask = torch.triu(torch.ones(MASK_SIZE, MASK_SIZE, dtype=torch.bool, device='npu'), diagonal=1)
mask = mask.view(1, 1, MASK_SIZE, MASK_SIZE)
attn_out, _ = torch_npu.npu_fused_infer_attention_score(
    q_bsh, k_bsh, v_bsh,
    num_heads=Hq,
    num_key_value_heads=Hkv,
    scale=scale,
    input_layout="BSH",
    sparse_mode=3,
    atten_mask=mask,
    actual_seq_lengths=[S],
    actual_seq_lengths_kv=[S],
)
print("attn_out:", attn_out.shape)  # [1, S, Q_DIM]

# Output projection
Wo = weights['model.layers.0.self_attn.o_proj.weight']
o = torch.matmul(attn_out, Wo.t())  # [1, S, D]

# Residual add
out = residual + o
print("out:", out.shape, out[0, 0, :4].float().tolist())

# ---- Dump ----
def dump(name, t):
    # Raw BF16 bytes: reinterpret as int16 so numpy can serialize losslessly.
    p = os.path.join(OUT_DIR, name + '.bin')
    a = t.contiguous().cpu().view(torch.int16).numpy().astype('int16')
    open(p, 'wb').write(a.tobytes())

# Save token_ids: little-endian int32 count followed by int32 ids.
with open(os.path.join(OUT_DIR, 'token_ids.bin'), 'wb') as f:
    f.write(struct.pack('<i', S))
    for tid in token_ids.cpu().tolist():
        f.write(struct.pack('<i', tid))

# Save inputs / intermediates
dump('x_input', x)  # embed result
dump('x_normed', xn)
dump('q_normed', q_normed)
dump('k_normed', k_normed)
dump('q_roped', q_roped)
dump('k_roped', k_roped)
dump('cos', cos)
dump('sin', sin)
dump('attn_out', attn_out)
dump('final_out', out)
# Save weights used (dtype=BF16)
for name, path_name in [
    ('model.layers.0.input_layernorm.weight', 'w_input_norm'),
    ('model.layers.0.self_attn.q_proj.weight', 'w_q_proj'),
    ('model.layers.0.self_attn.k_proj.weight', 'w_k_proj'),
    ('model.layers.0.self_attn.v_proj.weight', 'w_v_proj'),
    ('model.layers.0.self_attn.o_proj.weight', 'w_o_proj'),
    ('model.layers.0.self_attn.q_norm.weight', 'w_q_norm'),
    ('model.layers.0.self_attn.k_norm.weight', 'w_k_norm'),
]:
    dump(path_name, weights[name])

with open(os.path.join(OUT_DIR, 'shape.txt'), 'w') as f:
    f.write(f"S={S}\nD={D}\nHq={Hq}\nHkv={Hkv}\nDh={Dh}\nQ_DIM={Q_DIM}\nKV_DIM={KV_DIM}\neps={eps}\ntheta={theta}\n")

print("\nAll dumps in:", OUT_DIR)
print("Final output first 4:", out[0, 0, :4].float().cpu().tolist())
|
|
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
"""Generate a GMM reference case using torch_npu.npu_grouped_matmul.

Dumps: x, w (unpermuted ggml-like layout), group_list, and y_ref to binary files
for the C++ POC to load and validate against.

Shapes are deliberately tiny so a human can inspect the dumps, but chosen so
that a wrong weight layout (e.g. forgetting the [E,I,D]->[E,D,I] permute)
would produce an obviously wrong y_ref.
"""
import os
import sys
import numpy as np
import struct

# Enable torch_npu
os.environ.setdefault('LD_LIBRARY_PATH', '')
import torch
import torch_npu

torch.npu.set_device(0)
torch.manual_seed(42)

# Toy Qwen3-like MoE shape: D=64 hidden, I=32 intermediate, E=8 experts, N*K=16 expanded tokens
# (small enough to eyeball; large enough to catch layout bugs)
D, I, E, TOTAL = 64, 32, 8, 16

# Input x: [TOTAL, D] BF16 — expanded routed tokens
x = torch.randn(TOTAL, D, dtype=torch.bfloat16).npu()

# Weight w: per-expert [I, D] BF16 — gate/up has this shape in HF
# We will stack into [E, I, D] and also provide [E, D, I] permuted for comparison
w_per_expert = [torch.randn(I, D, dtype=torch.bfloat16).npu() for _ in range(E)]
w_stacked_IDL = torch.stack(w_per_expert, dim=0)  # [E, I, D]

# group_list: counts of tokens per expert, sum = TOTAL
group_list = torch.tensor([3, 2, 1, 2, 1, 3, 2, 2], dtype=torch.int64).npu()
assert group_list.sum().item() == TOTAL

# Reference: use torch_npu.npu_grouped_matmul
# Per cann-recipes: weight needs to be in [E, D, I] for matmul y = x @ w (y shape [total, I])
# i.e. per-expert w is transposed from HF's [I, D] to [D, I]
w_transposed = w_stacked_IDL.transpose(1, 2).contiguous()  # [E, D, I]

# Call GMM: y = x @ w, result [TOTAL, I]
# group_type=0: grouping along the M (token) axis;
# group_list_type=1: group_list entries are per-expert COUNTS, not prefix sums.
y_ref = torch_npu.npu_grouped_matmul(
    [x],  # x list
    [w_transposed],  # weight list (transposed)
    group_list=group_list,
    group_type=0,
    group_list_type=1,  # counts
    split_item=3  # single-in single-out
)[0]  # unwrap tensor list

print("x shape:", x.shape, x.dtype)
print("w_stacked_IDL shape:", w_stacked_IDL.shape, w_stacked_IDL.dtype)
print("w_transposed shape:", w_transposed.shape)
print("group_list:", group_list.cpu().tolist())
print("y_ref shape:", y_ref.shape)
print("y_ref[0, 0:4]:", y_ref[0, 0:4].cpu().float().tolist())

# Save binary dumps
out_dir = 'tests/poc_data'
os.makedirs(out_dir, exist_ok=True)

def dump_bf16(name, tensor):
    # Raw BF16 bytes via int16 reinterpret — no value conversion.
    path = os.path.join(out_dir, name + '.bin')
    arr = tensor.contiguous().cpu().view(torch.int16).numpy().astype('int16')
    with open(path, 'wb') as f:
        f.write(arr.tobytes())
    print(f" wrote {name}.bin: {arr.shape} int16 = BF16 raw, {arr.nbytes} bytes")

def dump_int64(name, tensor):
    # Native-endian int64 raw bytes.
    path = os.path.join(out_dir, name + '.bin')
    arr = tensor.contiguous().cpu().numpy().astype('int64')
    with open(path, 'wb') as f:
        f.write(arr.tobytes())
    print(f" wrote {name}.bin: {arr.shape} int64, {arr.nbytes} bytes")

# HF-style weight layout (ggml stores similar): [E, I, D] = what C++ gets from safetensors after stack
dump_bf16('x', x)
dump_bf16('w_hf_EID', w_stacked_IDL)  # C++ input weight (HF layout)
dump_bf16('w_ref_EDI', w_transposed)  # Already-permuted reference (for debug)
dump_int64('group_list', group_list)
dump_bf16('y_ref', y_ref)

# Also dump shapes header
with open(os.path.join(out_dir, 'shapes.txt'), 'w') as f:
    f.write(f"D={D}\nI={I}\nE={E}\nTOTAL={TOTAL}\n")

print("\nAll dumps in:", out_dir)
print("\nTo validate: C++ loads w_hf_EID, permutes [0,2,1] to [E,D,I], NZ-casts, calls GMMV4, "
      "compares output to y_ref.")
|
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
"""Generate a linear (y = x @ W.T) reference for a realistic Qwen3 attention shape."""
import os, struct, torch, torch_npu

torch.npu.set_device(0)
torch.manual_seed(7)

# Prompt length, hidden size, projection width (q_proj output).
N, D, OUT = 5, 4096, 8192  # prompt len, hidden, q_dim

# NOTE: RNG draw order (x first, then W) is part of the reference contract.
x = torch.randn(N, D, dtype=torch.bfloat16).npu()
W = torch.randn(OUT, D, dtype=torch.bfloat16).npu()  # HF layout [out, in]
# y = x @ W.T, shape [N, OUT]
y_ref = torch.matmul(x, W.t())

out_dir = 'tests/mm_data'
os.makedirs(out_dir, exist_ok=True)


def dump(name, t):
    """Write a tensor as raw BF16 bytes (serialized via int16 reinterpret)."""
    target = os.path.join(out_dir, name + '.bin')
    raw = t.contiguous().cpu().view(torch.int16).numpy().astype('int16')
    with open(target, 'wb') as fh:
        fh.write(raw.tobytes())


for tag, tensor in (('x', x), ('W', W), ('y_ref', y_ref)):
    dump(tag, tensor)

with open(os.path.join(out_dir, 'shape.txt'), 'w') as f:
    f.write(f"N={N}\nD={D}\nOUT={OUT}\n")
print(f"N={N} D={D} OUT={OUT}, y_ref[0, :4] = {y_ref[0, :4].float().cpu().tolist()}")
|
|
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
"""Generate MoE layer forward reference for Qwen3-235B layer 0.

Input: hidden_states from attention output (use attn_data/final_out.bin as input — realistic).
Output: hidden_states after MoE + residual.

Must be run AFTER the attention reference script, since it reads
tests/attn_data/final_out.bin as its input activation.
"""
import os, json, math, torch, torch_npu
from safetensors.torch import load_file

torch.npu.set_device(0)
torch.set_grad_enabled(False)

MODEL_DIR = '/path/to/Qwen3-235B-A22B-Instruct-2507-BF16'
OUT_DIR = 'tests/moe_data'
os.makedirs(OUT_DIR, exist_ok=True)

cfg = json.load(open(os.path.join(MODEL_DIR, 'config.json')))
D = cfg['hidden_size']            # 4096
I = cfg['moe_intermediate_size']  # 1536
E = cfg['num_experts']            # 128
TK = cfg['num_experts_per_tok']   # 8
eps = cfg['rms_norm_eps']
norm_topk = cfg.get('norm_topk_prob', True)

# Use attention output as input (more realistic than random)
# Raw file is BF16 bytes; reinterpret int16 -> bf16. S must match the
# attention script's prompt length (5 tokens).
attn_out_raw = open('tests/attn_data/final_out.bin', 'rb').read()
S = 5
x_in = torch.frombuffer(bytearray(attn_out_raw), dtype=torch.int16).view(1, S, D).view(torch.bfloat16).npu()
print(f"x_in: {x_in.shape}")

# Load required weights for layer 0
idx = json.load(open(os.path.join(MODEL_DIR, 'model.safetensors.index.json')))
wm = idx['weight_map']

needed = [f'model.layers.0.post_attention_layernorm.weight',
          f'model.layers.0.mlp.gate.weight']
for e in range(E):
    for p in ['gate_proj', 'up_proj', 'down_proj']:
        needed.append(f'model.layers.0.mlp.experts.{e}.{p}.weight')

shards = sorted({wm[n] for n in needed})
weights = {}
for sh in shards:
    t = load_file(os.path.join(MODEL_DIR, sh))
    for n in needed:
        if n in t:
            weights[n] = t[n].to('npu')
print("loaded %d tensors from %d shards" % (len(weights), len(shards)))

# Residual = input (re-added after the expert mix)
residual = x_in

# Post-attention RmsNorm
xn, _ = torch_npu.npu_rms_norm(x_in, weights['model.layers.0.post_attention_layernorm.weight'], epsilon=eps)
xn_flat = xn.view(S, D)  # flatten batch

# Router: logits [S, E]
W_router = weights['model.layers.0.mlp.gate.weight']  # [E, D]
logits = xn_flat @ W_router.t()  # [S, E]

# Top-k softmax
# Softmax over the top-k logits equals the renormalized top-k slice of the
# full softmax (softmax is monotonic), matching HF's norm_topk_prob path.
topk_logits, topk_idx = logits.topk(TK, dim=-1)  # both [S, TK]
topk_weights = torch.softmax(topk_logits.float(), dim=-1)  # [S, TK] F32
if norm_topk:
    # After the softmax above the sum is already 1; this renormalization is a
    # near-no-op kept for parity with the HF reference path.
    topk_weights = topk_weights / (topk_weights.sum(dim=-1, keepdim=True) + 1e-20)
topk_weights = topk_weights.to(torch.bfloat16)
topk_idx = topk_idx.to(torch.int32)

print(f"topk_idx[0]: {topk_idx[0].cpu().tolist()}")
print(f"topk_weights[0]: {topk_weights[0].cpu().float().tolist()}")

# MoE forward — loop over tokens (simple reference, not optimized)
out_flat = torch.zeros(S, D, dtype=torch.bfloat16, device='npu')
for s in range(S):
    token = xn_flat[s]  # [D]
    acc = torch.zeros(D, dtype=torch.bfloat16, device='npu')
    for k in range(TK):
        e = int(topk_idx[s, k].item())
        w = topk_weights[s, k]
        Wg = weights[f'model.layers.0.mlp.experts.{e}.gate_proj.weight']  # [I, D]
        Wu = weights[f'model.layers.0.mlp.experts.{e}.up_proj.weight']  # [I, D]
        Wd = weights[f'model.layers.0.mlp.experts.{e}.down_proj.weight']  # [D, I]
        gate = token @ Wg.t()  # [I]
        up = token @ Wu.t()
        act = torch.nn.functional.silu(gate) * up  # SwiGLU
        down = act @ Wd.t()  # [D]
        acc = acc + w * down
    out_flat[s] = acc

moe_out = out_flat.view(1, S, D)
final_out = residual + moe_out
print(f"final_out[0,0,:4] = {final_out[0,0,:4].float().cpu().tolist()}")

# Dump
def dump(name, t):
    # Raw BF16 bytes via int16 reinterpret.
    p = os.path.join(OUT_DIR, name + '.bin')
    a = t.contiguous().cpu().view(torch.int16).numpy().astype('int16')
    open(p, 'wb').write(a.tobytes())

dump('x_in', x_in)
dump('final_out', final_out)
dump('moe_out', moe_out)
dump('router', W_router)
dump('xn', xn)
dump('topk_w', topk_weights)  # [S, TK] BF16 (normalized)
dump('out_flat', out_flat)  # [S, D] BF16 — moe contrib before residual

# expert_idx as int32 dump (raw bytes)
topk_idx_bytes = topk_idx.contiguous().cpu().numpy().astype('int32').tobytes()
open(os.path.join(OUT_DIR, 'topk_idx.bin'), 'wb').write(topk_idx_bytes)

with open(os.path.join(OUT_DIR, 'shape.txt'), 'w') as f:
    f.write(f"S={S}\nD={D}\nI={I}\nE={E}\nTK={TK}\n")

print(f"\nDumps in {OUT_DIR}")
|
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
"""Generate a RmsNorm reference using PyTorch."""
import os, struct
import numpy as np
import torch
import torch_npu

torch.npu.set_device(0)
torch.manual_seed(123)

N, D = 5, 4096  # 5 tokens, Qwen3 hidden_size
eps = 1e-6

# RNG draw order (x first, then gamma) is part of the reference contract.
x = torch.randn(N, D, dtype=torch.bfloat16).npu()
gamma = torch.randn(D, dtype=torch.bfloat16).npu() * 0.1 + 1.0

# Use torch_npu's npu_rms_norm if available, else do it manually
y_ref, _ = torch_npu.npu_rms_norm(x, gamma, epsilon=eps)

out_dir = 'tests/rms_norm_data'
os.makedirs(out_dir, exist_ok=True)


def dump_bf16(name, t):
    """Serialize a BF16 tensor as raw bytes (int16 reinterpret); return the path."""
    target = os.path.join(out_dir, name + '.bin')
    raw = t.contiguous().cpu().view(torch.int16).numpy().astype('int16')
    with open(target, 'wb') as fh:
        fh.write(raw.tobytes())
    return target


for tag, tensor in (('x', x), ('gamma', gamma), ('y_ref', y_ref)):
    dump_bf16(tag, tensor)

with open(os.path.join(out_dir, 'shape.txt'), 'w') as f:
    f.write(f"N={N}\nD={D}\neps={eps}\n")

print(f"x shape: {x.shape}, gamma: {gamma.shape}, y_ref: {y_ref.shape}")
print("y_ref[0, :8]:", y_ref[0, :8].float().cpu().tolist())
print("saved in", out_dir)
|
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
"""Re-generate RoPE reference using explicit HF formula (not torch_npu.npu_apply_rotary_pos_emb).

Reads q_normed/k_normed dumps produced by the attention reference script,
applies the HF rotate-half RoPE formula, reports the max deviation from the
previously dumped npu_apply_rotary_pos_emb result, then OVERWRITES
q_roped/k_roped/cos/sin in tests/attn_data with the HF-formula ground truth.
"""
import os, math, torch, torch_npu
torch.npu.set_device(0)
torch.manual_seed(42)

S = 5; Hq = 64; Hkv = 4; Dh = 128
theta = 5e6
data = 'tests/attn_data'

def load_bf16(name, shape):
    # Read raw BF16 bytes dumped by the attention script (int16 reinterpret).
    raw = open(os.path.join(data, name + '.bin'), 'rb').read()
    a = torch.frombuffer(bytearray(raw), dtype=torch.int16).view(*shape).view(torch.bfloat16)
    return a.npu()

q = load_bf16('q_normed', [1, S, Hq, Dh])
k = load_bf16('k_normed', [1, S, Hkv, Dh])

# Compute cos/sin identical to HF (rope_theta=5e6, 0..S positions)
inv_freq = 1.0 / (theta ** (torch.arange(0, Dh, 2, dtype=torch.float32).npu() / Dh))
pos = torch.arange(S, device='npu').float().unsqueeze(-1)
freqs = pos * inv_freq
emb = torch.cat([freqs, freqs], dim=-1)  # [S, Dh]
cos = emb.cos().to(torch.bfloat16)  # [S, Dh]
sin = emb.sin().to(torch.bfloat16)

# HF (Qwen3) style RoPE: q_rot = q * cos + rotate_half(q) * sin
def rotate_half(x):
    # Split the last dim in half and rotate: [x1, x2] -> [-x2, x1].
    h = x.shape[-1] // 2
    x1 = x[..., :h]
    x2 = x[..., h:]
    return torch.cat([-x2, x1], dim=-1)

# Broadcast cos/sin from [S, Dh] to [1, S, 1, Dh]
cos_b = cos.unsqueeze(0).unsqueeze(2)
sin_b = sin.unsqueeze(0).unsqueeze(2)

q_roped_hf = q * cos_b + rotate_half(q) * sin_b
k_roped_hf = k * cos_b + rotate_half(k) * sin_b

print("HF-style q_roped[0,0,:4]:", q_roped_hf[0,0,0,:4].float().cpu().tolist())
print("cos[0,:4]:", cos[0,:4].float().cpu().tolist())
print("sin[0,:4]:", sin[0,:4].float().cpu().tolist())
print("cos[1,:4]:", cos[1,:4].float().cpu().tolist())

# Compare with existing q_roped (from torch_npu.npu_apply_rotary_pos_emb)
# NOTE: this read MUST happen before the dump() calls below overwrite the file.
old_q_roped = load_bf16('q_roped', [1, S, Hq, Dh])
diff = (q_roped_hf - old_q_roped).float().abs().max().item()
print(f"\nDiff between HF formula and npu_apply: max={diff:.4f}")

# Save HF version as ground truth
def dump(name, t):
    # Overwrites the attention script's dump of the same name, in-place.
    p = os.path.join(data, name + '.bin')
    a = t.contiguous().cpu().view(torch.int16).numpy().astype('int16')
    open(p, 'wb').write(a.tobytes())
dump('q_roped', q_roped_hf)
dump('k_roped', k_roped_hf)
# Overwrite cos, sin to [1, S, Dh] layout
dump('cos', cos.unsqueeze(0))  # [1, S, Dh]
dump('sin', sin.unsqueeze(0))

print("\nOverwrote q_roped, k_roped, cos, sin with HF-formula ground truth.")
|
|
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
# tp_launch.sh — launcher for TP>1 multi-process qwen3-moe-aclnn.
#
# Usage: ./tp_launch.sh <tp_size> <bin> [args...]
# e.g. ./tp_launch.sh 16 ./build/qwen3-moe-aclnn --model-dir ... --prompt "..." --n-predict 20
#
# Each rank runs as a separate process with:
#   ASCEND_RT_VISIBLE_DEVICES=<rank>
#   TP_RANK=<rank> TP_SIZE=<tp_size>
#   HCCL_WHITELIST_DISABLE=1
# rank 0 creates /tmp/hccl_root_info.bin; other ranks wait for it.
set -euo pipefail

TP_SIZE="${1:?tp_size required}"; shift
BIN="${1:?binary required}"; shift

# Clean any stale HCCL coordination file
rm -f /tmp/hccl_root_info.bin

export HCCL_WHITELIST_DISABLE=1
# Benchmark-tuned defaults (bench_hccl_adv.sh 2026-04-21):
#   ring:200 + OP_EXPANSION_MODE=AIV + OP_BASE_FFTS_MODE_ENABLE=1 → ~18.8 t/s median
#   vs baseline (auto) ~12 t/s. +54% from HCCL env knobs alone.
export HCCL_ALGO="${HCCL_ALGO:-level0:ring}"
export HCCL_BUFFSIZE="${HCCL_BUFFSIZE:-200}"
export HCCL_OP_EXPANSION_MODE="${HCCL_OP_EXPANSION_MODE:-AIV}"
export HCCL_OP_BASE_FFTS_MODE_ENABLE="${HCCL_OP_BASE_FFTS_MODE_ENABLE:-1}"
# TASK_QUEUE_ENABLE=2: aggressive async task queueing (marginal gain on top of AIV+FFTS)
export TASK_QUEUE_ENABLE="${TASK_QUEUE_ENABLE:-2}"

# Launch ranks 1..N-1 in background with stdin/stdout redirected to /dev/null / logfile.
# Launch rank 0 LAST in foreground, inheriting the terminal stdin/stdout — so --interactive works.
pids=()
for rank in $(seq 1 $((TP_SIZE - 1))); do
  logfile="/tmp/tp_rank_${rank}.log"
  env ASCEND_RT_VISIBLE_DEVICES=${rank} \
      TP_RANK=${rank} \
      TP_SIZE=${TP_SIZE} \
      "${BIN}" "$@" < /dev/null > "${logfile}" 2>&1 &
  pids+=($!)
  echo "[tp_launch] rank ${rank} pid=$! log=${logfile}"
done

# Give ranks 1..N-1 a moment to reach HcclCommInitRootInfo's file-wait before rank 0 writes it.
sleep 1

# Rank 0 in foreground — terminal stdin/stdout passthrough for REPL.
# BUGFIX: under `set -e`, a non-zero exit from rank 0 used to abort the script
# right here — `ec=$?` never ran and the background ranks were never reaped.
# Suspend errexit around the foreground run so the exit code is always
# captured and the wait loop below always executes.
set +e
env ASCEND_RT_VISIBLE_DEVICES=0 \
    TP_RANK=0 \
    TP_SIZE=${TP_SIZE} \
    "${BIN}" "$@"
ec=$?
set -e

# Wait for background ranks to finish (rank 0 exit signals end-of-work, but they may take a bit).
# Guard the expansion: with TP_SIZE=1 the array is empty, and expanding an
# empty array errors under `set -u` on bash < 4.4.
if [ "${#pids[@]}" -gt 0 ]; then
  for i in "${!pids[@]}"; do
    wait "${pids[$i]}" || true
  done
fi
exit $ec
|
|
@@ -0,0 +1,221 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include "device_weights.h"
|
| 2 |
+
#include "aclnn_ops.h"
|
| 3 |
+
#include <cstdio>
|
| 4 |
+
#include <cstring>
|
| 5 |
+
#include <vector>
|
| 6 |
+
|
| 7 |
+
bool DeviceWeightsLoader::load_tensor_full_(const std::string& name, DeviceBuffer& buf) {
|
| 8 |
+
const auto* m = st_.get(name);
|
| 9 |
+
if (!m) { fprintf(stderr, "load_tensor_full_: missing %s\n", name.c_str()); return false; }
|
| 10 |
+
const void* host = st_.data_ptr(*m);
|
| 11 |
+
if (!host) { fprintf(stderr, "load_tensor_full_: null host ptr %s\n", name.c_str()); return false; }
|
| 12 |
+
buf.alloc(m->nbytes);
|
| 13 |
+
ACL_CHECK(aclrtMemcpy(buf.get(), m->nbytes, host, m->nbytes, ACL_MEMCPY_HOST_TO_DEVICE));
|
| 14 |
+
return true;
|
| 15 |
+
}
|
| 16 |
+
|
| 17 |
+
bool DeviceWeightsLoader::load_tensor_row_slice_(const std::string& name,
|
| 18 |
+
int64_t row_lo, int64_t row_hi,
|
| 19 |
+
DeviceBuffer& buf) {
|
| 20 |
+
const auto* m = st_.get(name);
|
| 21 |
+
if (!m) { fprintf(stderr, "load_tensor_row_slice_: missing %s\n", name.c_str()); return false; }
|
| 22 |
+
if (m->shape.empty()) { fprintf(stderr, "%s: empty shape\n", name.c_str()); return false; }
|
| 23 |
+
int64_t D0 = m->shape[0];
|
| 24 |
+
if (row_hi > D0 || row_lo < 0 || row_hi <= row_lo) {
|
| 25 |
+
fprintf(stderr, "load_tensor_row_slice_: %s bad range [%ld,%ld) vs D0=%ld\n",
|
| 26 |
+
name.c_str(), row_lo, row_hi, D0);
|
| 27 |
+
return false;
|
| 28 |
+
}
|
| 29 |
+
size_t elem = sdtype_size(m->dtype);
|
| 30 |
+
size_t inner = 1;
|
| 31 |
+
for (size_t i = 1; i < m->shape.size(); i++) inner *= m->shape[i];
|
| 32 |
+
size_t row_bytes = inner * elem;
|
| 33 |
+
size_t slice_bytes = (row_hi - row_lo) * row_bytes;
|
| 34 |
+
|
| 35 |
+
const auto* host = (const char*)st_.data_ptr(*m);
|
| 36 |
+
buf.alloc(slice_bytes);
|
| 37 |
+
ACL_CHECK(aclrtMemcpy(buf.get(), slice_bytes,
|
| 38 |
+
host + row_lo * row_bytes, slice_bytes,
|
| 39 |
+
ACL_MEMCPY_HOST_TO_DEVICE));
|
| 40 |
+
return true;
|
| 41 |
+
}
|
| 42 |
+
|
| 43 |
+
bool DeviceWeightsLoader::load_tensor_col_slice_(const std::string& name,
|
| 44 |
+
int64_t col_lo, int64_t col_hi,
|
| 45 |
+
DeviceBuffer& buf) {
|
| 46 |
+
const auto* m = st_.get(name);
|
| 47 |
+
if (!m || m->shape.size() < 2) {
|
| 48 |
+
fprintf(stderr, "load_tensor_col_slice_: bad shape %s\n", name.c_str()); return false;
|
| 49 |
+
}
|
| 50 |
+
int64_t D0 = m->shape[0];
|
| 51 |
+
int64_t D1 = m->shape[1];
|
| 52 |
+
if (col_hi > D1 || col_lo < 0 || col_hi <= col_lo) {
|
| 53 |
+
fprintf(stderr, "load_tensor_col_slice_: bad range %ld-%ld D1=%ld\n",
|
| 54 |
+
col_lo, col_hi, D1); return false;
|
| 55 |
+
}
|
| 56 |
+
size_t elem = sdtype_size(m->dtype);
|
| 57 |
+
int64_t new_cols = col_hi - col_lo;
|
| 58 |
+
size_t slice_bytes = D0 * new_cols * elem;
|
| 59 |
+
buf.alloc(slice_bytes);
|
| 60 |
+
|
| 61 |
+
// Need to copy row-by-row since source has stride D1 but dest has stride new_cols.
|
| 62 |
+
const auto* host = (const char*)st_.data_ptr(*m);
|
| 63 |
+
std::vector<char> staging(slice_bytes);
|
| 64 |
+
size_t src_row = D1 * elem;
|
| 65 |
+
size_t dst_row = new_cols * elem;
|
| 66 |
+
size_t col_off = col_lo * elem;
|
| 67 |
+
for (int64_t r = 0; r < D0; r++) {
|
| 68 |
+
std::memcpy(staging.data() + r * dst_row, host + r * src_row + col_off, dst_row);
|
| 69 |
+
}
|
| 70 |
+
ACL_CHECK(aclrtMemcpy(buf.get(), slice_bytes, staging.data(), slice_bytes,
|
| 71 |
+
ACL_MEMCPY_HOST_TO_DEVICE));
|
| 72 |
+
return true;
|
| 73 |
+
}
|
| 74 |
+
|
| 75 |
+
// Load the TP-replicated (non-sharded) weights: token embedding, LM head,
// and the final norm gain. Every rank holds full copies of all three.
bool DeviceWeightsLoader::load_shared(SharedWeights& out) {
    return load_tensor_full_("model.embed_tokens.weight", out.embed_tokens)
        && load_tensor_full_("lm_head.weight", out.lm_head)
        && load_tensor_full_("model.norm.weight", out.final_norm);
}
|
| 81 |
+
|
| 82 |
+
bool DeviceWeightsLoader::load_moe(int L, aclrtStream stream, LayerMoEWeights& out) {
|
| 83 |
+
const int64_t E = cfg_.num_experts;
|
| 84 |
+
const int64_t D = cfg_.hidden_size;
|
| 85 |
+
const int64_t I_full = cfg_.moe_intermediate_size;
|
| 86 |
+
const int64_t I_rank = cfg_.i_per_rank;
|
| 87 |
+
const size_t elem = 2; // BF16
|
| 88 |
+
|
| 89 |
+
auto base = "model.layers." + std::to_string(L);
|
| 90 |
+
|
| 91 |
+
// 1. Router [E, D] — small, fully replicated
|
| 92 |
+
if (!load_tensor_full_(base + ".mlp.gate.weight", out.router)) return false;
|
| 93 |
+
|
| 94 |
+
// 2. MoE expert weights: need to stack 128 experts + TP slice + permute
|
| 95 |
+
// HF gate/up: each expert [I_full, D] → TP slice rows to [I_rank, D]
|
| 96 |
+
// HF down: each expert [D, I_full] → TP slice cols to [D, I_rank]
|
| 97 |
+
|
| 98 |
+
auto load_and_stack = [&](const std::string& proj_name,
|
| 99 |
+
bool is_down, DeviceBuffer& final_buf) -> bool {
|
| 100 |
+
// HF shape for gate/up: [I_full, D]; for down: [D, I_full]
|
| 101 |
+
// After TP slice: gate/up rows [I_rank, D]; down cols [D, I_rank]
|
| 102 |
+
// Stacked:
|
| 103 |
+
// gate/up: [E, I_rank, D] → permute to [E, D, I_rank]
|
| 104 |
+
// down: [E, D, I_rank] → permute to [E, I_rank, D]
|
| 105 |
+
|
| 106 |
+
int64_t K_in, N_out;
|
| 107 |
+
bool row_slice;
|
| 108 |
+
if (!is_down) {
|
| 109 |
+
K_in = I_rank; // HF first dim after row-slice
|
| 110 |
+
N_out = D;
|
| 111 |
+
row_slice = true;
|
| 112 |
+
} else {
|
| 113 |
+
K_in = D;
|
| 114 |
+
N_out = I_rank;
|
| 115 |
+
row_slice = false; // col slice
|
| 116 |
+
}
|
| 117 |
+
|
| 118 |
+
// Stage: stacked HF-layout [E, K_in, N_out] on device (before permute)
|
| 119 |
+
size_t elem_stack = K_in * N_out * elem;
|
| 120 |
+
DeviceBuffer stacked_hf(E * elem_stack);
|
| 121 |
+
|
| 122 |
+
// For each expert, load + TP slice + memcpy to stacked_hf[e]
|
| 123 |
+
// We use the existing row_slice/col_slice helpers on a per-expert basis.
|
| 124 |
+
DeviceBuffer tmp;
|
| 125 |
+
for (int e = 0; e < E; e++) {
|
| 126 |
+
std::string name = base + ".mlp.experts." + std::to_string(e) + "." + proj_name + ".weight";
|
| 127 |
+
if (row_slice) {
|
| 128 |
+
int64_t lo = cfg_.tp_rank * I_rank;
|
| 129 |
+
int64_t hi = lo + I_rank;
|
| 130 |
+
if (!load_tensor_row_slice_(name, lo, hi, tmp)) return false;
|
| 131 |
+
} else {
|
| 132 |
+
int64_t lo = cfg_.tp_rank * I_rank;
|
| 133 |
+
int64_t hi = lo + I_rank;
|
| 134 |
+
if (!load_tensor_col_slice_(name, lo, hi, tmp)) return false;
|
| 135 |
+
}
|
| 136 |
+
if (tmp.size != elem_stack) {
|
| 137 |
+
fprintf(stderr, "load_moe: expert %d %s slice size %zu != expected %zu\n",
|
| 138 |
+
e, name.c_str(), tmp.size, elem_stack);
|
| 139 |
+
return false;
|
| 140 |
+
}
|
| 141 |
+
// Synchronous D2D: tmp is about to be reallocated in the next iteration,
|
| 142 |
+
// so we cannot enqueue an async copy that would still reference it.
|
| 143 |
+
ACL_CHECK(aclrtMemcpy(
|
| 144 |
+
(char*)stacked_hf.get() + e * elem_stack, elem_stack,
|
| 145 |
+
tmp.get(), elem_stack,
|
| 146 |
+
ACL_MEMCPY_DEVICE_TO_DEVICE));
|
| 147 |
+
}
|
| 148 |
+
|
| 149 |
+
// Now permute stacked_hf [E, K_in, N_out] → final [E, N_out, K_in] row-major
|
| 150 |
+
// (swap last two dims)
|
| 151 |
+
final_buf.alloc(E * elem_stack);
|
| 152 |
+
const int64_t d0 = E, d1 = K_in, d2 = N_out;
|
| 153 |
+
// View stacked_hf with permuted strides pointing to same data:
|
| 154 |
+
// logical shape [E, N_out, K_in], strides [K_in*N_out, 1, N_out]
|
| 155 |
+
// (since physical is [E, K_in, N_out] row-major with strides [K_in*N_out, N_out, 1])
|
| 156 |
+
auto t_src = make_acl_tensor(stacked_hf.get(), ACL_BF16,
|
| 157 |
+
{d0, d2, d1}, // [E, N_out, K_in]
|
| 158 |
+
{d1 * d2, 1, d2});
|
| 159 |
+
auto t_dst = make_contig_tensor(final_buf.get(), ACL_BF16, {d0, d2, d1});
|
| 160 |
+
inplace_copy(stream, t_dst.get(), t_src.get());
|
| 161 |
+
// Must sync before stacked_hf goes out of scope — the inplace_copy is async and
|
| 162 |
+
// reads from stacked_hf's memory. If we return without syncing, DeviceBuffer's
|
| 163 |
+
// destructor frees stacked_hf while the permute kernel is still running, producing
|
| 164 |
+
// garbage in final_buf.
|
| 165 |
+
ACL_CHECK(aclrtSynchronizeStream(stream));
|
| 166 |
+
return true;
|
| 167 |
+
};
|
| 168 |
+
|
| 169 |
+
if (!load_and_stack("gate_proj", false, out.gate_exps)) return false;
|
| 170 |
+
if (!load_and_stack("up_proj", false, out.up_exps)) return false;
|
| 171 |
+
if (!load_and_stack("down_proj", true, out.down_exps)) return false;
|
| 172 |
+
|
| 173 |
+
return true;
|
| 174 |
+
}
|
| 175 |
+
|
| 176 |
+
bool DeviceWeightsLoader::load_attention(int L, LayerAttnWeights& out) {
|
| 177 |
+
auto base = "model.layers." + std::to_string(L);
|
| 178 |
+
|
| 179 |
+
if (!load_tensor_full_(base + ".input_layernorm.weight", out.input_layernorm)) return false;
|
| 180 |
+
if (!load_tensor_full_(base + ".post_attention_layernorm.weight", out.post_attention_layernorm)) return false;
|
| 181 |
+
if (!load_tensor_full_(base + ".self_attn.q_norm.weight", out.q_norm)) return false;
|
| 182 |
+
if (!load_tensor_full_(base + ".self_attn.k_norm.weight", out.k_norm)) return false;
|
| 183 |
+
|
| 184 |
+
const int64_t head_dim = cfg_.head_dim;
|
| 185 |
+
const int64_t q_full = cfg_.num_attention_heads * head_dim; // 64 * 128 = 8192
|
| 186 |
+
|
| 187 |
+
// q_proj: [q_full, D], shard rows by head. Each rank gets n_heads_per_rank heads.
|
| 188 |
+
int64_t q_rows_per_rank = cfg_.n_heads_per_rank * head_dim;
|
| 189 |
+
int64_t q_row_lo = cfg_.tp_rank * q_rows_per_rank;
|
| 190 |
+
int64_t q_row_hi = q_row_lo + q_rows_per_rank;
|
| 191 |
+
if (!load_tensor_row_slice_(base + ".self_attn.q_proj.weight",
|
| 192 |
+
q_row_lo, q_row_hi, out.q_proj)) return false;
|
| 193 |
+
|
| 194 |
+
// k_proj, v_proj: HF shape [num_kv * head_dim, D].
|
| 195 |
+
// Case A (tp <= n_kv): split rows across ranks, each rank gets n_kv/tp KV heads.
|
| 196 |
+
// Case B (tp > n_kv): each rank gets exactly ONE KV head; group of (tp/n_kv) ranks share it.
|
| 197 |
+
// kv_head_idx = tp_rank / (tp_size / n_kv)
|
| 198 |
+
if (cfg_.tp_size <= cfg_.num_key_value_heads) {
|
| 199 |
+
int64_t kv_rows_per_rank = cfg_.n_kv_heads_per_rank * head_dim;
|
| 200 |
+
int64_t kv_row_lo = cfg_.tp_rank * kv_rows_per_rank;
|
| 201 |
+
int64_t kv_row_hi = kv_row_lo + kv_rows_per_rank;
|
| 202 |
+
if (!load_tensor_row_slice_(base + ".self_attn.k_proj.weight", kv_row_lo, kv_row_hi, out.k_proj)) return false;
|
| 203 |
+
if (!load_tensor_row_slice_(base + ".self_attn.v_proj.weight", kv_row_lo, kv_row_hi, out.v_proj)) return false;
|
| 204 |
+
} else {
|
| 205 |
+
// GQA replicated-group mode: 1 KV head per rank, selected by group.
|
| 206 |
+
int64_t ranks_per_kv = cfg_.tp_size / cfg_.num_key_value_heads;
|
| 207 |
+
int64_t kv_head_idx = cfg_.tp_rank / ranks_per_kv;
|
| 208 |
+
int64_t kv_row_lo = kv_head_idx * head_dim;
|
| 209 |
+
int64_t kv_row_hi = kv_row_lo + head_dim;
|
| 210 |
+
if (!load_tensor_row_slice_(base + ".self_attn.k_proj.weight", kv_row_lo, kv_row_hi, out.k_proj)) return false;
|
| 211 |
+
if (!load_tensor_row_slice_(base + ".self_attn.v_proj.weight", kv_row_lo, kv_row_hi, out.v_proj)) return false;
|
| 212 |
+
}
|
| 213 |
+
|
| 214 |
+
// o_proj: [D, q_full], row-parallel → shard cols (input dim) by head.
|
| 215 |
+
int64_t o_col_lo = q_row_lo; // same slicing as q rows
|
| 216 |
+
int64_t o_col_hi = q_row_hi;
|
| 217 |
+
if (!load_tensor_col_slice_(base + ".self_attn.o_proj.weight",
|
| 218 |
+
o_col_lo, o_col_hi, out.o_proj)) return false;
|
| 219 |
+
|
| 220 |
+
return true;
|
| 221 |
+
}
|
|
@@ -0,0 +1,816 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// main_cli.cpp — qwen3-moe-aclnn entry point.
|
| 2 |
+
//
|
| 3 |
+
// Usage:
|
| 4 |
+
// qwen3-moe-aclnn --model-dir <path> --prompt "<text>" --n-predict <N>
|
| 5 |
+
// [--tp-size 1|16] [--vocab <path>] [--max-seq N] [--num-layers N]
|
| 6 |
+
// [--chat] [--temperature 0.7] [--top-k 20] [--top-p 0.8] [--seed N]
|
| 7 |
+
// [--no-stream]
|
| 8 |
+
//
|
| 9 |
+
// At TP>1 each rank is a separate process (env TP_RANK=<i>, TP_SIZE=<n>) launched by
|
| 10 |
+
// scripts/tp_launch.sh. Only rank 0 prints text output.
|
| 11 |
+
#include "runner.h"
|
| 12 |
+
#include "tokenizer.h"
|
| 13 |
+
|
| 14 |
+
// Escape hatch for HCCL broadcast from within CLI (defined in runner.cpp)
|
| 15 |
+
HcclCtx* runner_hccl_ctx_shim(Runner& r);
|
| 16 |
+
|
| 17 |
+
#include <algorithm>
|
| 18 |
+
#include <chrono>
|
| 19 |
+
#include <cmath>
|
| 20 |
+
#include <cstdio>
|
| 21 |
+
#include <cstdlib>
|
| 22 |
+
#include <cstring>
|
| 23 |
+
#include <iostream>
|
| 24 |
+
#include <random>
|
| 25 |
+
#include <string>
|
| 26 |
+
#include <vector>
|
| 27 |
+
|
| 28 |
+
// Widen a bfloat16 bit pattern to IEEE binary32: bf16 is exactly the top 16
// bits of a float, so shifting into the high half and bit-copying suffices.
static float bf16_to_float(uint16_t x) {
    const uint32_t bits = (uint32_t)x << 16;
    float out;
    std::memcpy(&out, &bits, sizeof(out));
    return out;
}
|
| 31 |
+
|
| 32 |
+
// Drop a trailing incomplete UTF-8 sequence from `s`, if present. When
// generation stops at the n_predict limit the final 1-3 bytes may be a
// partial multi-byte codepoint; trimming them keeps downstream JSON
// encoding on valid UTF-8.
static std::string utf8_trim_incomplete(const std::string& s) {
    const size_t n = s.size();
    if (n == 0) return s;
    // Scan backward over at most 4 bytes looking for the sequence's lead byte.
    for (size_t back = 0; back < 4 && back < n; back++) {
        const size_t pos = n - 1 - back;
        const unsigned char b = (unsigned char)s[pos];
        if (b < 0x80) return s;            // ASCII tail: string ends complete
        if ((b & 0xC0) == 0x80) continue;  // continuation byte: keep scanning
        // Lead byte: how many bytes does it announce?
        // 110xxxxx → 2, 1110xxxx → 3, 11110xxx → 4.
        size_t want;
        if ((b & 0xE0) == 0xC0)      want = 2;
        else if ((b & 0xF0) == 0xE0) want = 3;
        else if ((b & 0xF8) == 0xF0) want = 4;
        else return s.substr(0, pos);      // malformed lead byte — drop it
        const size_t got = back + 1;
        return (got >= want) ? s : s.substr(0, pos);
    }
    // Four continuation bytes with no lead byte in reach: leave untouched.
    return s;
}
|
| 56 |
+
|
| 57 |
+
// Parsed command-line configuration with Qwen3-tuned defaults. Each field is
// set by the like-named flag in parse_args(); PLD fields only take effect
// when --pld is given.
struct Args {
    std::string model_dir;                            // --model-dir (required)
    std::string prompt = "The capital of France is";  // --prompt
    std::string vocab_path = "tokenizer_data/vocab.bin"; // --vocab
    int n_predict = 100;        // --n-predict: max tokens to generate
    int tp_size = 1;            // --tp-size or TP_SIZE env
    int tp_rank = 0;            // TP_RANK env (set per-process by tp_launch.sh)
    int num_layers = 0; // 0 = auto
    int max_seq = 512;          // --max-seq: KV cache / context cap
    int device_id = 0;          // --device: ACL device index
    bool chat_template = false; // --chat: wrap prompt in Qwen3 chat template
    bool stream = true;         // --no-stream clears this
    bool interactive = false;   // -i / --interactive: REPL mode
    bool reset_each_turn = false; // if true, REPL clears KV cache between turns (stateless)
    std::string system_prompt; // optional system role for chat mode
    std::string prompt_file; // read prompt from file (avoids shell escaping)
    bool pld_enabled = false; // prompt lookup decoding
    int pld_k = 10; // bench_pld_k.sh: K=10 median 105 t/s (3/3 runs 100+), K=8 was 35
    int pld_ngram = 1; // n-gram match size — 1 with multi-level fallback best
    bool pld_adaptive = false; // fixed K=10 is simpler and mean-optimal; adaptive --pld-adaptive
    int pld_min_hist = 20; // skip PLD until history >= this (avoid early-token false matches)
    // PLD degeneration guard (on by default): prevents PLD from amplifying repetition loops.
    bool pld_guard = true; // --pld-no-guard disables
    int pld_guard_distinct = 3; // reject draft if distinct tokens < this (≥K/3 heuristic)
    int pld_guard_tail = 6; // reject if draft[0] matches all last N hist tokens
    int pld_loop_warn = 8; // warn once when N consecutive identical tokens emitted
    float temperature = 0.0f; // 0 = greedy
    int top_k = 0; // 0 = disabled
    float top_p = 1.0f; // 1.0 = disabled
    uint64_t seed = 0; // 0 = use time
    // Qwen3 EOS tokens (from generation_config.json)
    std::vector<int> eos_ids = {151645, 151643};
};
|
| 90 |
+
|
| 91 |
+
static bool parse_args(int argc, char** argv, Args& a) {
|
| 92 |
+
for (int i = 1; i < argc; i++) {
|
| 93 |
+
std::string s = argv[i];
|
| 94 |
+
auto next = [&](const char* f)->const char* {
|
| 95 |
+
if (i + 1 >= argc) { fprintf(stderr, "missing value for %s\n", f); return nullptr; }
|
| 96 |
+
return argv[++i];
|
| 97 |
+
};
|
| 98 |
+
if (s == "--model-dir") { auto v = next(s.c_str()); if (!v) return false; a.model_dir = v; }
|
| 99 |
+
else if (s == "--prompt") { auto v = next(s.c_str()); if (!v) return false; a.prompt = v; }
|
| 100 |
+
else if (s == "--vocab") { auto v = next(s.c_str()); if (!v) return false; a.vocab_path = v; }
|
| 101 |
+
else if (s == "--n-predict") { auto v = next(s.c_str()); if (!v) return false; a.n_predict = std::atoi(v); }
|
| 102 |
+
else if (s == "--tp-size") { auto v = next(s.c_str()); if (!v) return false; a.tp_size = std::atoi(v); }
|
| 103 |
+
else if (s == "--num-layers") { auto v = next(s.c_str()); if (!v) return false; a.num_layers = std::atoi(v); }
|
| 104 |
+
else if (s == "--max-seq") { auto v = next(s.c_str()); if (!v) return false; a.max_seq = std::atoi(v); }
|
| 105 |
+
else if (s == "--device") { auto v = next(s.c_str()); if (!v) return false; a.device_id = std::atoi(v); }
|
| 106 |
+
else if (s == "--temperature") { auto v = next(s.c_str()); if (!v) return false; a.temperature = (float)std::atof(v); }
|
| 107 |
+
else if (s == "--top-k") { auto v = next(s.c_str()); if (!v) return false; a.top_k = std::atoi(v); }
|
| 108 |
+
else if (s == "--top-p") { auto v = next(s.c_str()); if (!v) return false; a.top_p = (float)std::atof(v); }
|
| 109 |
+
else if (s == "--seed") { auto v = next(s.c_str()); if (!v) return false; a.seed = (uint64_t)std::atoll(v); }
|
| 110 |
+
else if (s == "--chat") { a.chat_template = true; }
|
| 111 |
+
else if (s == "--no-stream") { a.stream = false; }
|
| 112 |
+
else if (s == "--interactive" || s == "-i") { a.interactive = true; }
|
| 113 |
+
else if (s == "--reset") { a.reset_each_turn = true; }
|
| 114 |
+
else if (s == "--system") { auto v = next(s.c_str()); if (!v) return false; a.system_prompt = v; }
|
| 115 |
+
else if (s == "--prompt-file") { auto v = next(s.c_str()); if (!v) return false; a.prompt_file = v; }
|
| 116 |
+
else if (s == "--pld") { a.pld_enabled = true; }
|
| 117 |
+
else if (s == "--pld-k") { auto v = next(s.c_str()); if (!v) return false; a.pld_k = std::atoi(v); }
|
| 118 |
+
else if (s == "--pld-ngram") { auto v = next(s.c_str()); if (!v) return false; a.pld_ngram = std::atoi(v); }
|
| 119 |
+
else if (s == "--pld-adaptive"){ a.pld_adaptive = true; }
|
| 120 |
+
else if (s == "--pld-fixed-k") { a.pld_adaptive = false; } // opt out of adaptive
|
| 121 |
+
else if (s == "--pld-min-hist"){ auto v = next(s.c_str()); if (!v) return false; a.pld_min_hist = std::atoi(v); }
|
| 122 |
+
else if (s == "--pld-no-guard"){ a.pld_guard = false; }
|
| 123 |
+
else if (s == "--pld-guard-distinct"){ auto v = next(s.c_str()); if (!v) return false; a.pld_guard_distinct = std::atoi(v); }
|
| 124 |
+
else if (s == "--pld-guard-tail"){ auto v = next(s.c_str()); if (!v) return false; a.pld_guard_tail = std::atoi(v); }
|
| 125 |
+
else if (s == "--pld-loop-warn"){ auto v = next(s.c_str()); if (!v) return false; a.pld_loop_warn = std::atoi(v); }
|
| 126 |
+
else if (s == "--help" || s == "-h") {
|
| 127 |
+
printf("Usage: %s --model-dir <path> [options]\n", argv[0]);
|
| 128 |
+
printf(" --prompt \"text\" prompt text (default: \"%s\")\n", a.prompt.c_str());
|
| 129 |
+
printf(" --prompt-file FILE read prompt from file (overrides --prompt)\n");
|
| 130 |
+
printf(" --n-predict N max tokens to generate (default: %d)\n", a.n_predict);
|
| 131 |
+
printf(" --tp-size N tensor parallelism (default: 1; or TP_SIZE env)\n");
|
| 132 |
+
printf(" --num-layers N limit layers, testing only (default: all)\n");
|
| 133 |
+
printf(" --max-seq N KV cache + context cap (default: %d)\n", a.max_seq);
|
| 134 |
+
printf(" --chat apply Qwen3 chat template\n");
|
| 135 |
+
printf(" --system \"text\" system role for chat\n");
|
| 136 |
+
printf(" --temperature F 0 = greedy; typical 0.7\n");
|
| 137 |
+
printf(" --top-k N 0 = disabled\n");
|
| 138 |
+
printf(" --top-p F 1.0 = disabled; typical 0.8\n");
|
| 139 |
+
printf(" --seed N 0 = time-based (default)\n");
|
| 140 |
+
printf(" --no-stream batch-print final text\n");
|
| 141 |
+
printf(" -i, --interactive REPL (multi-turn memory when --chat)\n");
|
| 142 |
+
printf(" --reset force stateless REPL (reset each turn)\n");
|
| 143 |
+
printf(" --pld enable Prompt Lookup Decoding (greedy only)\n");
|
| 144 |
+
printf(" --pld-k N draft window size (default: 4)\n");
|
| 145 |
+
printf(" --pld-ngram N match n-gram size (default: 2; multi-level fallback)\n");
|
| 146 |
+
printf(" --pld-adaptive adjust K based on recent accept rate\n");
|
| 147 |
+
printf(" --pld-min-hist N skip PLD until history >= N tokens (default: 20)\n");
|
| 148 |
+
printf(" --pld-no-guard disable degeneration guard (dangerous: can amplify loops)\n");
|
| 149 |
+
printf(" --pld-guard-distinct N reject draft with distinct tokens < N (default: 3)\n");
|
| 150 |
+
printf(" --pld-guard-tail N reject draft if draft[0] matches all last N hist (default: 6)\n");
|
| 151 |
+
printf(" --pld-loop-warn N warn once on N consecutive identical emitted tokens (default: 8)\n");
|
| 152 |
+
return false;
|
| 153 |
+
}
|
| 154 |
+
else { fprintf(stderr, "unknown arg: %s\n", s.c_str()); return false; }
|
| 155 |
+
}
|
| 156 |
+
if (a.model_dir.empty()) { fprintf(stderr, "--model-dir required\n"); return false; }
|
| 157 |
+
if (const char* r = std::getenv("TP_RANK")) a.tp_rank = std::atoi(r);
|
| 158 |
+
if (const char* s = std::getenv("TP_SIZE")) a.tp_size = std::atoi(s);
|
| 159 |
+
return true;
|
| 160 |
+
}
|
| 161 |
+
|
| 162 |
+
// Sample next token from logits. temperature=0 → greedy argmax. Otherwise top-k / top-p.
|
| 163 |
+
static int sample_token(const std::vector<uint16_t>& logits_bf16, int64_t V,
|
| 164 |
+
float temperature, int top_k, float top_p, std::mt19937& rng) {
|
| 165 |
+
if (temperature <= 0.0f) {
|
| 166 |
+
int best = 0;
|
| 167 |
+
float bv = bf16_to_float(logits_bf16[0]);
|
| 168 |
+
for (int64_t i = 1; i < V; i++) {
|
| 169 |
+
float v = bf16_to_float(logits_bf16[i]);
|
| 170 |
+
if (v > bv) { bv = v; best = (int)i; }
|
| 171 |
+
}
|
| 172 |
+
return best;
|
| 173 |
+
}
|
| 174 |
+
|
| 175 |
+
// Build (logit, id) list as float
|
| 176 |
+
std::vector<std::pair<float, int>> scored;
|
| 177 |
+
scored.reserve(V);
|
| 178 |
+
for (int64_t i = 0; i < V; i++) {
|
| 179 |
+
scored.emplace_back(bf16_to_float(logits_bf16[i]) / temperature, (int)i);
|
| 180 |
+
}
|
| 181 |
+
|
| 182 |
+
// Top-k: keep highest k entries (partial sort)
|
| 183 |
+
if (top_k > 0 && top_k < (int)scored.size()) {
|
| 184 |
+
std::nth_element(scored.begin(), scored.begin() + top_k, scored.end(),
|
| 185 |
+
[](const auto& a, const auto& b){ return a.first > b.first; });
|
| 186 |
+
scored.resize(top_k);
|
| 187 |
+
}
|
| 188 |
+
|
| 189 |
+
// Sort descending for top-p
|
| 190 |
+
std::sort(scored.begin(), scored.end(),
|
| 191 |
+
[](const auto& a, const auto& b){ return a.first > b.first; });
|
| 192 |
+
|
| 193 |
+
// Softmax (numerically stable)
|
| 194 |
+
float maxv = scored[0].first;
|
| 195 |
+
double sum = 0;
|
| 196 |
+
for (auto& p : scored) { p.first = std::exp(p.first - maxv); sum += p.first; }
|
| 197 |
+
for (auto& p : scored) p.first /= (float)sum;
|
| 198 |
+
|
| 199 |
+
// Top-p nucleus
|
| 200 |
+
if (top_p > 0.0f && top_p < 1.0f) {
|
| 201 |
+
double cum = 0;
|
| 202 |
+
size_t cutoff = scored.size();
|
| 203 |
+
for (size_t i = 0; i < scored.size(); i++) {
|
| 204 |
+
cum += scored[i].first;
|
| 205 |
+
if (cum >= top_p) { cutoff = i + 1; break; }
|
| 206 |
+
}
|
| 207 |
+
scored.resize(cutoff);
|
| 208 |
+
// re-normalize
|
| 209 |
+
double s = 0; for (auto& p : scored) s += p.first;
|
| 210 |
+
for (auto& p : scored) p.first /= (float)s;
|
| 211 |
+
}
|
| 212 |
+
|
| 213 |
+
// Sample
|
| 214 |
+
std::uniform_real_distribution<float> U(0.0f, 1.0f);
|
| 215 |
+
float r = U(rng), acc = 0.0f;
|
| 216 |
+
for (auto& p : scored) {
|
| 217 |
+
acc += p.first;
|
| 218 |
+
if (r <= acc) return p.second;
|
| 219 |
+
}
|
| 220 |
+
return scored.back().second;
|
| 221 |
+
}
|
| 222 |
+
|
| 223 |
+
// Broadcast a prompt's token ids from rank 0 to all ranks. With TP>1 the
// non-master ranks must receive the tokens before prefill. Two-phase HCCL
// broadcast: first the count, then the id buffer (both as INT32 staged
// through device memory, since HCCL operates on device buffers).
// Returns false on any broadcast/validation failure; no-op when tp_size <= 1.
static bool broadcast_token_ids(Runner& runner, std::vector<int32_t>& ids,
                                int64_t max_seq, bool is_master) {
    const ModelConfig& cfg = runner.cfg();
    if (cfg.tp_size <= 1) return true;

    // Step 1: broadcast the token count (int32 staged on device).
    DeviceBuffer cnt_dev(4);
    int32_t cnt = is_master ? (int32_t)ids.size() : 0;
    ACL_CHECK(aclrtMemcpy(cnt_dev.get(), 4, &cnt, 4, ACL_MEMCPY_HOST_TO_DEVICE));
    // Runner keeps its HcclCtx private; the shim (defined in runner.cpp) is the
    // deliberate escape hatch that hands us the context for direct broadcasts.
    extern HcclCtx* runner_hccl_ctx_shim(Runner& r); // forward from runner.cpp
    HcclCtx* ctx = runner_hccl_ctx_shim(runner);
    if (!ctx) return false;

    if (!hccl_broadcast(*ctx, cnt_dev.get(), 1, HCCL_DATA_TYPE_INT32, 0, runner.stream())) return false;
    // Sync before the D2H read: the broadcast is enqueued on the stream.
    ACL_CHECK(aclrtSynchronizeStream(runner.stream()));
    ACL_CHECK(aclrtMemcpy(&cnt, 4, cnt_dev.get(), 4, ACL_MEMCPY_DEVICE_TO_HOST));
    // Validate before sizing the id buffer: a garbage count would otherwise
    // drive the allocation below.
    if (cnt <= 0 || cnt > (int32_t)max_seq) {
        fprintf(stderr, "[rank %d] broadcast: bad count %d\n", cfg.tp_rank, cnt);
        return false;
    }

    // Step 2: broadcast the id buffer itself. Only the master uploads; the
    // other ranks receive into their (uninitialized) device buffer.
    DeviceBuffer ids_dev(cnt * 4);
    if (is_master) {
        ACL_CHECK(aclrtMemcpy(ids_dev.get(), cnt*4, ids.data(), cnt*4, ACL_MEMCPY_HOST_TO_DEVICE));
    }
    if (!hccl_broadcast(*ctx, ids_dev.get(), cnt, HCCL_DATA_TYPE_INT32, 0, runner.stream())) return false;
    ACL_CHECK(aclrtSynchronizeStream(runner.stream()));
    if (!is_master) {
        // Non-master ranks pull the result back to host for tokenthan-free use.
        ids.resize(cnt);
        ACL_CHECK(aclrtMemcpy(ids.data(), cnt*4, ids_dev.get(), cnt*4, ACL_MEMCPY_DEVICE_TO_HOST));
    }
    return true;
}
|
| 265 |
+
|
| 266 |
+
// Run one generation turn. Assumes KV cache is reset. Returns perf summary.
|
| 267 |
+
// Per-turn performance summary produced by run_turn().
struct TurnStats {
    double prefill_ms = 0; double decode_ms = 0;  // wall-clock time per phase
    int n_prompt = 0; int decoded = 0; bool hit_eos = false;  // prompt len / generated count / stopped-on-EOS
};
|
| 271 |
+
static TurnStats run_turn(Runner& runner, Tokenizer& tokenizer, const Args& args,
|
| 272 |
+
const std::string& prompt, std::mt19937& rng, bool is_master) {
|
| 273 |
+
TurnStats st;
|
| 274 |
+
|
| 275 |
+
// --- Tokenize (on master; broadcast for TP>1) ---
|
| 276 |
+
std::vector<int32_t> input_ids;
|
| 277 |
+
if (is_master) {
|
| 278 |
+
auto raw = tokenizer.encode_via_python(args.model_dir, prompt, args.chat_template);
|
| 279 |
+
if (raw.empty()) return st;
|
| 280 |
+
input_ids.reserve(raw.size());
|
| 281 |
+
for (int v : raw) input_ids.push_back((int32_t)v);
|
| 282 |
+
}
|
| 283 |
+
if (args.tp_size > 1) {
|
| 284 |
+
if (!broadcast_token_ids(runner, input_ids, args.max_seq, is_master)) return st;
|
| 285 |
+
}
|
| 286 |
+
if (input_ids.empty()) return st;
|
| 287 |
+
|
| 288 |
+
const int64_t V = runner.cfg().vocab_size;
|
| 289 |
+
std::vector<uint16_t> logits_h(V);
|
| 290 |
+
auto load_logits = [&](DeviceBuffer& buf) {
|
| 291 |
+
ACL_CHECK(aclrtMemcpy(logits_h.data(), V*2, buf.get(), V*2, ACL_MEMCPY_DEVICE_TO_HOST));
|
| 292 |
+
};
|
| 293 |
+
auto is_eos = [&](int id) {
|
| 294 |
+
for (int e : args.eos_ids) if (id == e) return true;
|
| 295 |
+
return false;
|
| 296 |
+
};
|
| 297 |
+
|
| 298 |
+
// --- Prefill ---
|
| 299 |
+
st.n_prompt = (int)input_ids.size();
|
| 300 |
+
auto t0 = std::chrono::steady_clock::now();
|
| 301 |
+
DeviceBuffer logits;
|
| 302 |
+
if (!runner.prefill(input_ids.data(), (int64_t)input_ids.size(), logits)) return st;
|
| 303 |
+
auto t1 = std::chrono::steady_clock::now();
|
| 304 |
+
st.prefill_ms = std::chrono::duration<double, std::milli>(t1 - t0).count();
|
| 305 |
+
|
| 306 |
+
load_logits(logits);
|
| 307 |
+
int next_id = sample_token(logits_h, V, args.temperature, args.top_k, args.top_p, rng);
|
| 308 |
+
|
| 309 |
+
if (is_master && args.stream) {
|
| 310 |
+
if (!args.chat_template) printf("%s", prompt.c_str());
|
| 311 |
+
printf("%s", tokenizer.decode(next_id).c_str());
|
| 312 |
+
fflush(stdout);
|
| 313 |
+
}
|
| 314 |
+
|
| 315 |
+
std::vector<int> generated = { next_id };
|
| 316 |
+
st.hit_eos = is_eos(next_id);
|
| 317 |
+
|
| 318 |
+
// All tokens (prompt + generated) for PLD n-gram lookup. Non-master ranks still need to
|
| 319 |
+
// track consistent length for HCCL broadcast of draft proposals.
|
| 320 |
+
std::vector<int32_t> hist;
|
| 321 |
+
hist.reserve(input_ids.size() + args.n_predict + 16);
|
| 322 |
+
for (auto x : input_ids) hist.push_back(x);
|
| 323 |
+
hist.push_back(next_id);
|
| 324 |
+
|
| 325 |
+
// PLD n-gram lookup: search for suffix match ending at end-of-hist; return K tokens following match.
|
| 326 |
+
// Longer matches = more reliable drafts. Multi-level: try n, fall back to smaller n if no match.
|
| 327 |
+
auto lookup_one = [&](int ngram, int K) -> std::vector<int32_t> {
|
| 328 |
+
int hs = (int)hist.size();
|
| 329 |
+
if (hs < ngram + 1 || K <= 0) return {};
|
| 330 |
+
for (int start = hs - ngram - 1; start >= 0; start--) {
|
| 331 |
+
bool match = true;
|
| 332 |
+
for (int k = 0; k < ngram; k++) {
|
| 333 |
+
if (hist[start + k] != hist[hs - ngram + k]) { match = false; break; }
|
| 334 |
+
}
|
| 335 |
+
if (match) {
|
| 336 |
+
int after = start + ngram;
|
| 337 |
+
std::vector<int32_t> d;
|
| 338 |
+
for (int k = 0; k < K && after + k < hs; k++) {
|
| 339 |
+
d.push_back(hist[after + k]);
|
| 340 |
+
if (is_eos(hist[after + k])) break;
|
| 341 |
+
}
|
| 342 |
+
if (!d.empty()) return d;
|
| 343 |
+
}
|
| 344 |
+
}
|
| 345 |
+
return {};
|
| 346 |
+
};
|
| 347 |
+
// Multi-level: try configured n first, then n-1, then n-2 (down to 1).
|
| 348 |
+
auto lookup_draft = [&](int ngram, int K) -> std::vector<int32_t> {
|
| 349 |
+
for (int n = ngram; n >= 1; n--) {
|
| 350 |
+
auto d = lookup_one(n, K);
|
| 351 |
+
if (!d.empty()) return d;
|
| 352 |
+
}
|
| 353 |
+
return {};
|
| 354 |
+
};
|
| 355 |
+
|
| 356 |
+
// Degeneration guard: classify a draft as repetition-induced so we can fall back to single
|
| 357 |
+
// decode (and avoid PLD amplifying model's own repetition loop into a runaway "W W W …" mess).
|
| 358 |
+
// Returns nullptr if draft is OK, else a short reason string for stats.
|
| 359 |
+
auto draft_degenerate = [&](const std::vector<int32_t>& d) -> const char* {
|
| 360 |
+
if (!args.pld_guard || d.empty()) return nullptr;
|
| 361 |
+
// (1) distinct-token count: a draft of K tokens with < args.pld_guard_distinct distinct
|
| 362 |
+
// values means n-gram is echoing a loop. Only apply when draft is long enough.
|
| 363 |
+
if ((int)d.size() >= 3) {
|
| 364 |
+
int distinct = 0;
|
| 365 |
+
for (int i = 0; i < (int)d.size(); i++) {
|
| 366 |
+
bool seen = false;
|
| 367 |
+
for (int j = 0; j < i; j++) { if (d[j] == d[i]) { seen = true; break; } }
|
| 368 |
+
if (!seen) distinct++;
|
| 369 |
+
}
|
| 370 |
+
if (distinct < args.pld_guard_distinct) return "low-distinct";
|
| 371 |
+
}
|
| 372 |
+
// (2) tail echo: if the last N hist tokens are all equal to draft[0], the model is already
|
| 373 |
+
// in a short loop — accepting the draft will just confirm the loop at batch speed.
|
| 374 |
+
int tail_n = std::min(args.pld_guard_tail, (int)hist.size());
|
| 375 |
+
if (tail_n >= 3) {
|
| 376 |
+
int matches = 0;
|
| 377 |
+
for (int i = (int)hist.size() - tail_n; i < (int)hist.size(); i++) {
|
| 378 |
+
if (hist[i] == d[0]) matches++;
|
| 379 |
+
}
|
| 380 |
+
if (matches == tail_n) return "tail-echo";
|
| 381 |
+
}
|
| 382 |
+
return nullptr;
|
| 383 |
+
};
|
| 384 |
+
|
| 385 |
+
// --- Decode loop ---
|
| 386 |
+
auto t2 = std::chrono::steady_clock::now();
|
| 387 |
+
int pld_verifies = 0, pld_accepted = 0;
|
| 388 |
+
int pld_rej_lowdist = 0, pld_rej_tailecho = 0; // guard rejection counters
|
| 389 |
+
bool loop_warned = false; // warn-once state
|
| 390 |
+
|
| 391 |
+
// Adaptive K state: recent accept counts for moving-average decisions
|
| 392 |
+
const int ADAPT_WINDOW = 8;
|
| 393 |
+
std::vector<int> recent_accepts;
|
| 394 |
+
int current_k = args.pld_k;
|
| 395 |
+
bool pld_disabled_adapt = false; // set true when recent accept rate is too low to benefit
|
| 396 |
+
|
| 397 |
+
while (st.decoded < args.n_predict - 1 && !st.hit_eos) {
|
| 398 |
+
// Adaptive K: scale K with recent accept rate.
|
| 399 |
+
// No auto-disable: since S=K+1 forward ≈ S=1 forward (latency-bound), even accept=0.1
|
| 400 |
+
// still nets slightly positive — PLD doesn't "hurt" as long as ngram lookup is cheap.
|
| 401 |
+
if (args.pld_adaptive && (int)recent_accepts.size() >= ADAPT_WINDOW) {
|
| 402 |
+
double avg = 0;
|
| 403 |
+
for (int a : recent_accepts) avg += a;
|
| 404 |
+
avg /= recent_accepts.size();
|
| 405 |
+
// Aim: K = 2*avg + 4 (generous window to catch upswings). Clamp [4, 12].
|
| 406 |
+
current_k = std::max(4, std::min(12, (int)std::round(2.0 * avg + 4.0)));
|
| 407 |
+
}
|
| 408 |
+
|
| 409 |
+
// Try PLD speculation path — skip until enough history accumulated
|
| 410 |
+
std::vector<int32_t> draft;
|
| 411 |
+
if (args.pld_enabled && (int)hist.size() >= args.pld_min_hist && is_master) {
|
| 412 |
+
draft = lookup_draft(args.pld_ngram, current_k);
|
| 413 |
+
// Degeneration guard: if draft looks like repetition-loop echo, drop it so this
|
| 414 |
+
// iteration falls through to normal single decode. This does NOT stop a loop the model
|
| 415 |
+
// is already in (greedy is deterministic), but it prevents PLD from running the loop
|
| 416 |
+
// at batch speed while masquerading as a speedup.
|
| 417 |
+
if (!draft.empty()) {
|
| 418 |
+
const char* reason = draft_degenerate(draft);
|
| 419 |
+
if (reason) {
|
| 420 |
+
if (reason[0] == 'l') pld_rej_lowdist++;
|
| 421 |
+
else pld_rej_tailecho++;
|
| 422 |
+
draft.clear();
|
| 423 |
+
}
|
| 424 |
+
}
|
| 425 |
+
}
|
| 426 |
+
// For TP>1, broadcast draft across ranks. Only broadcast if master has a non-empty draft;
|
| 427 |
+
// otherwise all ranks take the no-draft path (normal decode).
|
| 428 |
+
bool has_draft = is_master ? !draft.empty() : false;
|
| 429 |
+
// Broadcast the "has_draft" flag (using a 1-element count: 1 = yes, 0 = no)
|
| 430 |
+
if (args.tp_size > 1) {
|
| 431 |
+
extern HcclCtx* runner_hccl_ctx_shim(Runner&);
|
| 432 |
+
HcclCtx* ctx = runner_hccl_ctx_shim(runner);
|
| 433 |
+
DeviceBuffer flag(4);
|
| 434 |
+
int32_t f = has_draft ? 1 : 0;
|
| 435 |
+
ACL_CHECK(aclrtMemcpy(flag.get(), 4, &f, 4, ACL_MEMCPY_HOST_TO_DEVICE));
|
| 436 |
+
hccl_broadcast(*ctx, flag.get(), 1, HCCL_DATA_TYPE_INT32, 0, runner.stream());
|
| 437 |
+
ACL_CHECK(aclrtSynchronizeStream(runner.stream()));
|
| 438 |
+
ACL_CHECK(aclrtMemcpy(&f, 4, flag.get(), 4, ACL_MEMCPY_DEVICE_TO_HOST));
|
| 439 |
+
has_draft = (f != 0);
|
| 440 |
+
if (has_draft) {
|
| 441 |
+
std::vector<int32_t> d = draft;
|
| 442 |
+
broadcast_token_ids(runner, d, args.max_seq, is_master);
|
| 443 |
+
draft = d;
|
| 444 |
+
} else {
|
| 445 |
+
draft.clear();
|
| 446 |
+
}
|
| 447 |
+
}
|
| 448 |
+
|
| 449 |
+
if (args.pld_enabled && (int)draft.size() >= 1 && args.temperature == 0.0f) {
|
| 450 |
+
// Batch verify: input = [next_id, draft[0], ..., draft[K-1]]
|
| 451 |
+
std::vector<int32_t> batch_input = { next_id };
|
| 452 |
+
for (auto d : draft) batch_input.push_back(d);
|
| 453 |
+
int S = (int)batch_input.size();
|
| 454 |
+
DeviceBuffer batch_logits;
|
| 455 |
+
if (!runner.decode_batch(batch_input.data(), S, batch_logits)) break;
|
| 456 |
+
|
| 457 |
+
std::vector<uint16_t> blh(S * V);
|
| 458 |
+
if (is_master) ACL_CHECK(aclrtMemcpy(blh.data(), S*V*2, batch_logits.get(), S*V*2, ACL_MEMCPY_DEVICE_TO_HOST));
|
| 459 |
+
|
| 460 |
+
// Accept longest prefix: draft[i] is "candidate" for position past+i+1.
|
| 461 |
+
// blh row i predicts position past+i+1 (follows batch_input[i]).
|
| 462 |
+
// Verify: blh[0].argmax == draft[0]? (i.e., does model agree with draft's first proposal)
|
| 463 |
+
int accept = 0, new_next = next_id;
|
| 464 |
+
if (is_master) {
|
| 465 |
+
for (int i = 0; i < S - 1; i++) {
|
| 466 |
+
int pred = 0; float bv = bf16_to_float(blh[i * V]);
|
| 467 |
+
for (int k = 1; k < V; k++) { float v = bf16_to_float(blh[i*V + k]); if (v > bv) { bv = v; pred = k; } }
|
| 468 |
+
if (pred == (int)draft[i]) accept++;
|
| 469 |
+
else { new_next = pred; break; }
|
| 470 |
+
}
|
| 471 |
+
if (accept == S - 1) {
|
| 472 |
+
// All draft accepted, bonus from last row
|
| 473 |
+
int pred = 0; float bv = bf16_to_float(blh[(S-1) * V]);
|
| 474 |
+
for (int k = 1; k < V; k++) { float v = bf16_to_float(blh[(S-1)*V + k]); if (v > bv) { bv = v; pred = k; } }
|
| 475 |
+
new_next = pred;
|
| 476 |
+
}
|
| 477 |
+
}
|
| 478 |
+
|
| 479 |
+
// Broadcast accept count + new_next across TP ranks
|
| 480 |
+
if (args.tp_size > 1) {
|
| 481 |
+
int32_t packed[2] = { (int32_t)accept, (int32_t)new_next };
|
| 482 |
+
std::vector<int32_t> p(packed, packed + 2);
|
| 483 |
+
broadcast_token_ids(runner, p, args.max_seq, is_master);
|
| 484 |
+
if (p.size() == 2) { accept = p[0]; new_next = p[1]; }
|
| 485 |
+
}
|
| 486 |
+
|
| 487 |
+
// Rewind KV for rejected drafts
|
| 488 |
+
int64_t rewind = (int64_t)(S - 1 - accept); // drafts not accepted (excluding bonus)
|
| 489 |
+
if (rewind > 0) runner.rewind_cache(rewind);
|
| 490 |
+
|
| 491 |
+
// Commit accepted drafts + bonus to hist and emit
|
| 492 |
+
for (int i = 0; i < accept; i++) {
|
| 493 |
+
int tok = (int)draft[i];
|
| 494 |
+
hist.push_back(tok);
|
| 495 |
+
generated.push_back(tok);
|
| 496 |
+
st.decoded++;
|
| 497 |
+
if (is_master && args.stream) { printf("%s", tokenizer.decode(tok).c_str()); fflush(stdout); }
|
| 498 |
+
if (is_eos(tok)) { st.hit_eos = true; break; }
|
| 499 |
+
}
|
| 500 |
+
pld_verifies++; pld_accepted += accept;
|
| 501 |
+
// Track recent accept for adaptive K
|
| 502 |
+
if (args.pld_adaptive) {
|
| 503 |
+
recent_accepts.push_back(accept);
|
| 504 |
+
if ((int)recent_accepts.size() > ADAPT_WINDOW) recent_accepts.erase(recent_accepts.begin());
|
| 505 |
+
}
|
| 506 |
+
if (st.hit_eos) break;
|
| 507 |
+
|
| 508 |
+
// Bonus token (new_next) is also committed
|
| 509 |
+
hist.push_back(new_next);
|
| 510 |
+
generated.push_back(new_next);
|
| 511 |
+
st.decoded++;
|
| 512 |
+
if (is_master && args.stream) { printf("%s", tokenizer.decode(new_next).c_str()); fflush(stdout); }
|
| 513 |
+
if (is_eos(new_next)) { st.hit_eos = true; break; }
|
| 514 |
+
next_id = new_next;
|
| 515 |
+
} else {
|
| 516 |
+
// Normal decode
|
| 517 |
+
DeviceBuffer logits2;
|
| 518 |
+
if (!runner.decode((int32_t)next_id, logits2)) break;
|
| 519 |
+
load_logits(logits2);
|
| 520 |
+
next_id = sample_token(logits_h, V, args.temperature, args.top_k, args.top_p, rng);
|
| 521 |
+
hist.push_back(next_id);
|
| 522 |
+
generated.push_back(next_id);
|
| 523 |
+
st.decoded++;
|
| 524 |
+
if (is_master && args.stream) { printf("%s", tokenizer.decode(next_id).c_str()); fflush(stdout); }
|
| 525 |
+
if (is_eos(next_id)) { st.hit_eos = true; break; }
|
| 526 |
+
}
|
| 527 |
+
// Loop-warn: emit a one-shot warning to stderr if the tail of generated is all-same-token.
|
| 528 |
+
// Does not stop generation (user may want to see what happens) — just flags output quality.
|
| 529 |
+
if (is_master && !loop_warned && args.pld_loop_warn > 0 &&
|
| 530 |
+
(int)generated.size() >= args.pld_loop_warn) {
|
| 531 |
+
int tail = args.pld_loop_warn;
|
| 532 |
+
int anchor = generated[(int)generated.size() - tail];
|
| 533 |
+
bool all_same = true;
|
| 534 |
+
for (int i = (int)generated.size() - tail + 1; i < (int)generated.size(); i++) {
|
| 535 |
+
if (generated[i] != anchor) { all_same = false; break; }
|
| 536 |
+
}
|
| 537 |
+
if (all_same) {
|
| 538 |
+
fprintf(stderr, "\n[warn] %d consecutive identical tokens — likely degeneration loop; output after this point is suspect\n", tail);
|
| 539 |
+
loop_warned = true;
|
| 540 |
+
}
|
| 541 |
+
}
|
| 542 |
+
}
|
| 543 |
+
auto t3 = std::chrono::steady_clock::now();
|
| 544 |
+
st.decode_ms = std::chrono::duration<double, std::milli>(t3 - t2).count();
|
| 545 |
+
if (is_master && args.pld_enabled) {
|
| 546 |
+
if (pld_verifies > 0) {
|
| 547 |
+
fprintf(stderr, "\n[pld] %d verifies, %d drafts accepted, avg=%.2f",
|
| 548 |
+
pld_verifies, pld_accepted, (double)pld_accepted / pld_verifies);
|
| 549 |
+
} else {
|
| 550 |
+
fprintf(stderr, "\n[pld] 0 verifies (all drafts blocked or none found)");
|
| 551 |
+
}
|
| 552 |
+
if (args.pld_guard && (pld_rej_lowdist + pld_rej_tailecho) > 0) {
|
| 553 |
+
fprintf(stderr, "; guard rejections: low-distinct=%d tail-echo=%d",
|
| 554 |
+
pld_rej_lowdist, pld_rej_tailecho);
|
| 555 |
+
}
|
| 556 |
+
fprintf(stderr, "\n");
|
| 557 |
+
}
|
| 558 |
+
|
| 559 |
+
if (is_master) {
|
| 560 |
+
if (args.stream) { printf("\n"); fflush(stdout); }
|
| 561 |
+
else {
|
| 562 |
+
std::string text = tokenizer.decode(generated);
|
| 563 |
+
printf("%s%s\n", args.chat_template ? "" : prompt.c_str(), text.c_str());
|
| 564 |
+
}
|
| 565 |
+
}
|
| 566 |
+
return st;
|
| 567 |
+
}
|
| 568 |
+
|
| 569 |
+
// Read the entire contents of `path` (opened binary) into `out`.
// Returns false — with a diagnostic on stderr — if the file cannot be opened,
// seeking/sizing fails, or fewer bytes than expected are read.
// A single trailing newline ("\n" or "\r\n", common in text files) is stripped.
static bool load_file(const std::string& path, std::string& out) {
    FILE* f = fopen(path.c_str(), "rb");
    if (!f) { fprintf(stderr, "[cli] cannot open %s\n", path.c_str()); return false; }
    // Size the file via seek-to-end. ftell() returns -1 on failure (e.g. a pipe
    // or >2GB file on 32-bit long); the original code fed that -1 straight into
    // resize(), which would throw or allocate garbage — check explicitly.
    if (fseek(f, 0, SEEK_END) != 0) {
        fprintf(stderr, "[cli] cannot seek %s\n", path.c_str());
        fclose(f);
        return false;
    }
    long sz = ftell(f);
    if (sz < 0) {
        fprintf(stderr, "[cli] cannot size %s\n", path.c_str());
        fclose(f);
        return false;
    }
    if (fseek(f, 0, SEEK_SET) != 0) {
        fprintf(stderr, "[cli] cannot seek %s\n", path.c_str());
        fclose(f);
        return false;
    }
    out.resize((size_t)sz);
    size_t n = (sz > 0) ? fread(&out[0], 1, (size_t)sz, f) : 0;
    fclose(f);
    if ((long)n != sz) { fprintf(stderr, "[cli] short read from %s\n", path.c_str()); return false; }
    // Strip a single trailing newline (common in text files); also handle CRLF.
    if (!out.empty() && out.back() == '\n') out.pop_back();
    if (!out.empty() && out.back() == '\r') out.pop_back();
    return true;
}
|
| 581 |
+
|
| 582 |
+
// CLI entry point: parse args, load tokenizer + model, then run either
// interactive mode (multi-turn chat or stateless per-prompt) or a one-shot
// generation via run_turn(). All TP ranks execute main(); only rank 0
// ("master") does console I/O, and control signals / token ids are broadcast
// to the other ranks over HCCL so every rank takes the same code path.
int main(int argc, char** argv) {
    Args args;
    if (!parse_args(argc, argv, args)) return 1;

    // --prompt-file overrides --prompt
    if (!args.prompt_file.empty()) {
        if (!load_file(args.prompt_file, args.prompt)) return 1;
    }

    const bool is_master = (args.tp_rank == 0);
    // Seed RNG from --seed when given, else from the monotonic clock.
    std::mt19937 rng(args.seed ? args.seed :
        (uint64_t)std::chrono::steady_clock::now().time_since_epoch().count());

    if (is_master) {
        printf("[cli] model=%s\n", args.model_dir.c_str());
        printf("[cli] tp=%d n_predict=%d temp=%.2f top_k=%d top_p=%.2f chat=%d interactive=%d\n",
               args.tp_size, args.n_predict, args.temperature, args.top_k, args.top_p,
               args.chat_template, args.interactive);
        fflush(stdout);
    }

    Tokenizer tokenizer;
    if (!tokenizer.load(args.vocab_path)) {
        fprintf(stderr, "[cli] failed to load vocab %s\n", args.vocab_path.c_str()); return 1;
    }

    Runner runner;
    int num_layers = args.num_layers;
    // --num-layers 0 means "read the layer count from the model's config.json".
    if (num_layers == 0) {
        ModelConfig probe;
        if (!probe.load_from_json(args.model_dir + "/config.json")) return 1;
        num_layers = (int)probe.num_hidden_layers;
    }
    if (!runner.init(args.model_dir, args.tp_size, args.tp_rank,
                     num_layers, args.max_seq, args.device_id)) return 1;
    // Opt-in per-op profiling via LCA_PROFILE=1.
    if (const char* p = std::getenv("LCA_PROFILE"); p && std::atoi(p) != 0) {
        runner.profile_enabled = true;
    }
    // Warmup: cut cold-start latency. Controlled via LCA_WARMUP env (default 0 to keep behavior).
    if (const char* w = std::getenv("LCA_WARMUP"); w) {
        int n = std::atoi(w);
        if (n > 0) runner.warmup(n);
    }

    if (args.interactive) {
        // Multi-turn (KV cache kept across turns) only when chat templating is on
        // and the user did not force per-turn resets.
        const bool multi_turn = args.chat_template && !args.reset_each_turn;
        if (is_master) {
            printf("\n[cli] === interactive mode ===\n");
            if (multi_turn) {
                printf("[cli] multi-turn chat (KV cache preserved). Commands: 'quit', 'reset'.\n");
                if (!args.system_prompt.empty()) {
                    printf("[cli] system: %s\n", args.system_prompt.c_str());
                }
            } else {
                printf("[cli] stateless mode (KV cache reset each turn). Command: 'quit'.\n");
                if (!args.chat_template) {
                    printf("[cli] (hint: add --chat for multi-turn conversational memory)\n");
                }
            }
            fflush(stdout);
        }

        // Conversation history: accumulated (role, content) pairs. System prompt seeded if present.
        std::vector<std::pair<std::string, std::string>> conversation;
        if (multi_turn && !args.system_prompt.empty()) {
            conversation.emplace_back("system", args.system_prompt);
        }

        auto* hccl_ctx = runner_hccl_ctx_shim(runner);

        // Signal types (broadcast as int32): 0 = normal turn, 1 = quit, 2 = reset.
        // Rank 0's decision is copied to a device buffer, HCCL-broadcast, then
        // read back so every rank agrees on the control flow for this turn.
        auto broadcast_signal = [&](int32_t sig)->int32_t {
            if (args.tp_size <= 1) return sig;
            DeviceBuffer s(4);
            ACL_CHECK(aclrtMemcpy(s.get(), 4, &sig, 4, ACL_MEMCPY_HOST_TO_DEVICE));
            hccl_broadcast(*hccl_ctx, s.get(), 1, HCCL_DATA_TYPE_INT32, 0, runner.stream());
            ACL_CHECK(aclrtSynchronizeStream(runner.stream()));
            int32_t r; ACL_CHECK(aclrtMemcpy(&r, 4, s.get(), 4, ACL_MEMCPY_DEVICE_TO_HOST));
            return r;
        };

        while (true) {
            std::string prompt;
            int32_t sig = 0;
            if (is_master) {
                printf("\n> "); fflush(stdout);
                if (!std::getline(std::cin, prompt)) sig = 1;           // EOF => quit
                else if (prompt == "quit" || prompt == "exit") sig = 1;
                else if (prompt == "reset") sig = 2;
                else if (prompt.empty()) sig = 3; // skip
            }
            sig = broadcast_signal(sig);
            if (sig == 1) break;
            if (sig == 2) {
                runner.reset_cache();
                conversation.clear();
                // Re-seed the system prompt after reset so behavior matches startup.
                if (multi_turn && !args.system_prompt.empty())
                    conversation.emplace_back("system", args.system_prompt);
                if (is_master) { printf("[cli] (cache + conversation reset)\n"); fflush(stdout); }
                continue;
            }
            if (sig == 3) continue;

            TurnStats st;
            if (multi_turn) {
                // Append user message and tokenize full conversation. Prefill DELTA only.
                if (is_master) conversation.emplace_back("user", prompt);
                // Also ranks 1..N-1 need to track conversation (needed for correct delta count on
                // subsequent turns if TP ever tokenizes per-rank — currently rank 0 tokenizes).
                std::vector<int32_t> full_ids;
                if (is_master) {
                    auto raw = tokenizer.encode_conversation_via_python(args.model_dir, conversation, /*gen_prompt=*/true);
                    full_ids.reserve(raw.size());
                    for (int v : raw) full_ids.push_back((int32_t)v);
                }
                // Broadcast full_ids (variable-length). Use the same shim as broadcast_token_ids.
                if (args.tp_size > 1) {
                    if (!broadcast_token_ids(runner, full_ids, args.max_seq, is_master)) break;
                }
                if (full_ids.empty()) { if (is_master) printf("[cli] tokenize failed\n"); continue; }

                // If the template re-rendered to something SHORTER than what is cached
                // (template drift), the cache prefix no longer matches — start over.
                int64_t past = runner.past_len();
                if ((int64_t)full_ids.size() < past) { runner.reset_cache(); past = 0; }
                std::vector<int32_t> delta(full_ids.begin() + past, full_ids.end());
                if (delta.empty()) {
                    if (is_master) printf("[cli] (no new tokens)\n");
                    continue;
                }
                // Overflow check — simple policy: warn + auto-reset if the turn + generation
                // would exceed max_seq. Conversation history is cleared (except --system) so
                // the user's current prompt still fits.
                if ((int64_t)(past + delta.size()) + args.n_predict > args.max_seq) {
                    if (is_master) {
                        fprintf(stderr, "[cli] context %ld + gen %d > max_seq %d — auto-resetting\n",
                                (long)(past + delta.size()), args.n_predict, args.max_seq);
                    }
                    runner.reset_cache();
                    // Rebuild conversation: keep only system + current user turn.
                    if (is_master) {
                        std::vector<std::pair<std::string, std::string>> fresh;
                        for (auto& m : conversation) if (m.first == "system") fresh.push_back(m);
                        if (!conversation.empty() && conversation.back().first == "user") {
                            fresh.push_back(conversation.back());
                        }
                        conversation = std::move(fresh);
                        auto raw = tokenizer.encode_conversation_via_python(args.model_dir, conversation, true);
                        full_ids.clear();
                        for (int v : raw) full_ids.push_back((int32_t)v);
                    }
                    if (args.tp_size > 1) {
                        if (!broadcast_token_ids(runner, full_ids, args.max_seq, is_master)) break;
                    }
                    delta.assign(full_ids.begin(), full_ids.end());
                    past = 0;
                }

                // --- Prefill the delta ---
                st.n_prompt = (int)delta.size();
                auto t0 = std::chrono::steady_clock::now();
                DeviceBuffer logits;
                if (!runner.prefill(delta.data(), (int64_t)delta.size(), logits)) break;
                auto t1 = std::chrono::steady_clock::now();
                st.prefill_ms = std::chrono::duration<double, std::milli>(t1 - t0).count();

                // Host-side logits staging (bf16, hence uint16_t storage; V*2 bytes).
                const int64_t V = runner.cfg().vocab_size;
                std::vector<uint16_t> logits_h(V);
                auto load_logits = [&](DeviceBuffer& buf) {
                    ACL_CHECK(aclrtMemcpy(logits_h.data(), V*2, buf.get(), V*2, ACL_MEMCPY_DEVICE_TO_HOST));
                };
                auto is_eos = [&](int id) {
                    for (int e : args.eos_ids) if (id == e) return true;
                    return false;
                };
                load_logits(logits);
                int next_id = sample_token(logits_h, V, args.temperature, args.top_k, args.top_p, rng);

                std::vector<int> assistant_ids = { next_id };
                if (is_master) { printf("%s", tokenizer.decode(next_id).c_str()); fflush(stdout); }
                st.hit_eos = is_eos(next_id);

                // Token-at-a-time decode until n_predict or EOS.
                auto t2 = std::chrono::steady_clock::now();
                for (int step = 1; step < args.n_predict && !st.hit_eos; step++) {
                    DeviceBuffer logits2;
                    if (!runner.decode((int32_t)next_id, logits2)) break;
                    load_logits(logits2);
                    next_id = sample_token(logits_h, V, args.temperature, args.top_k, args.top_p, rng);
                    assistant_ids.push_back(next_id);
                    st.decoded++;
                    if (is_master) { printf("%s", tokenizer.decode(next_id).c_str()); fflush(stdout); }
                    if (is_eos(next_id)) { st.hit_eos = true; break; }
                }
                auto t3 = std::chrono::steady_clock::now();
                st.decode_ms = std::chrono::duration<double, std::milli>(t3 - t2).count();
                if (is_master) { printf("\n"); fflush(stdout); }

                // Record assistant reply in conversation (strip trailing EOS before decode,
                // and trim incomplete UTF-8 tail if generation was cut mid-codepoint).
                if (is_master) {
                    std::vector<int> content_ids;
                    for (int id : assistant_ids) { if (is_eos(id)) break; content_ids.push_back(id); }
                    conversation.emplace_back("assistant", utf8_trim_incomplete(tokenizer.decode(content_ids)));
                }
            } else {
                // Stateless: reset cache, one-shot prompt
                runner.reset_cache();
                st = run_turn(runner, tokenizer, args, prompt, rng, is_master);
            }

            // Per-turn performance line (rank 0 only).
            if (is_master) {
                double tgs = (st.decode_ms > 0) ? (st.decoded * 1000.0 / st.decode_ms) : 0.0;
                printf("[perf] prefill %d tok %.0fms decode %d tok %.0fms = %.2f t/s%s past_len=%ld\n",
                       st.n_prompt, st.prefill_ms, st.decoded, st.decode_ms, tgs,
                       st.hit_eos ? " (EOS)" : "", runner.past_len());
                fflush(stdout);
            }
        }
        if (is_master) printf("[cli] bye\n");
        return 0;
    }

    // One-shot mode
    TurnStats st = run_turn(runner, tokenizer, args, args.prompt, rng, is_master);
    if (is_master) runner.print_profile_summary();
    if (is_master) {
        if (st.hit_eos) printf("[cli] (hit EOS)\n");
        printf("\n[perf] prefill: %.1fms for %d tokens = %.2f t/s\n",
               st.prefill_ms, st.n_prompt,
               (st.prefill_ms > 0) ? (st.n_prompt * 1000.0 / st.prefill_ms) : 0.0);
        if (st.decoded > 0) {
            printf("[perf] decode : %.1fms for %d tokens = %.2f t/s (TG)\n",
                   st.decode_ms, st.decoded, (st.decoded * 1000.0) / st.decode_ms);
        }
    }
    return 0;
}
|
|
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include "model_config.h"
|
| 2 |
+
|
| 3 |
+
#include <cstdio>
|
| 4 |
+
#include <fstream>
|
| 5 |
+
#include <sstream>
|
| 6 |
+
|
| 7 |
+
#include "json.hpp"
|
| 8 |
+
using json = nlohmann::json;
|
| 9 |
+
|
| 10 |
+
bool ModelConfig::load_from_json(const std::string& path) {
|
| 11 |
+
std::ifstream f(path);
|
| 12 |
+
if (!f) {
|
| 13 |
+
fprintf(stderr, "ModelConfig: cannot open %s\n", path.c_str());
|
| 14 |
+
return false;
|
| 15 |
+
}
|
| 16 |
+
json j;
|
| 17 |
+
try { f >> j; } catch (std::exception& e) {
|
| 18 |
+
fprintf(stderr, "ModelConfig: bad json: %s\n", e.what());
|
| 19 |
+
return false;
|
| 20 |
+
}
|
| 21 |
+
|
| 22 |
+
auto get = [&](const char* k, auto def) {
|
| 23 |
+
if (j.contains(k) && !j[k].is_null()) return j[k].get<decltype(def)>();
|
| 24 |
+
return def;
|
| 25 |
+
};
|
| 26 |
+
|
| 27 |
+
vocab_size = get("vocab_size", (int64_t)0);
|
| 28 |
+
hidden_size = get("hidden_size", (int64_t)0);
|
| 29 |
+
intermediate_size = get("intermediate_size", (int64_t)0);
|
| 30 |
+
moe_intermediate_size = get("moe_intermediate_size", (int64_t)0);
|
| 31 |
+
num_hidden_layers = get("num_hidden_layers", (int64_t)0);
|
| 32 |
+
num_attention_heads = get("num_attention_heads", (int64_t)0);
|
| 33 |
+
num_key_value_heads = get("num_key_value_heads", (int64_t)0);
|
| 34 |
+
head_dim = get("head_dim", (int64_t)0);
|
| 35 |
+
num_experts = get("num_experts", (int64_t)0);
|
| 36 |
+
num_experts_per_tok = get("num_experts_per_tok", (int64_t)0);
|
| 37 |
+
max_position_embeddings = get("max_position_embeddings", (int64_t)0);
|
| 38 |
+
rope_theta = (float)get("rope_theta", (double)10000.0);
|
| 39 |
+
rms_norm_eps = (float)get("rms_norm_eps", (double)1e-6);
|
| 40 |
+
norm_topk_prob = get("norm_topk_prob", true);
|
| 41 |
+
tie_word_embeddings = get("tie_word_embeddings", false);
|
| 42 |
+
bos_token_id = get("bos_token_id", (int64_t)0);
|
| 43 |
+
eos_token_id = get("eos_token_id", (int64_t)0);
|
| 44 |
+
|
| 45 |
+
// Sanity
|
| 46 |
+
if (num_attention_heads == 0 || head_dim == 0 || hidden_size == 0) {
|
| 47 |
+
fprintf(stderr, "ModelConfig: required fields missing\n");
|
| 48 |
+
return false;
|
| 49 |
+
}
|
| 50 |
+
return true;
|
| 51 |
+
}
|
| 52 |
+
|
| 53 |
+
void ModelConfig::compute_derived(int tps, int tpr) {
|
| 54 |
+
tp_size = tps;
|
| 55 |
+
tp_rank = tpr;
|
| 56 |
+
|
| 57 |
+
// Attention Q: split by head
|
| 58 |
+
if (num_attention_heads % tp_size != 0) {
|
| 59 |
+
fprintf(stderr, "WARN: num_attention_heads=%ld not divisible by tp_size=%d\n",
|
| 60 |
+
num_attention_heads, tp_size);
|
| 61 |
+
}
|
| 62 |
+
n_heads_per_rank = num_attention_heads / tp_size;
|
| 63 |
+
q_dim_per_rank = n_heads_per_rank * head_dim;
|
| 64 |
+
|
| 65 |
+
// Attention KV: GQA sharding.
|
| 66 |
+
// Case A (tp_size <= num_kv_heads): split KV heads across ranks.
|
| 67 |
+
// n_kv_heads_per_rank = num_kv_heads / tp_size
|
| 68 |
+
// Case B (tp_size > num_kv_heads): each rank gets ONE kv head shared by multiple ranks.
|
| 69 |
+
// Ranks in the same "group" share one kv head (ratio = tp_size / num_kv_heads).
|
| 70 |
+
// n_kv_heads_per_rank = 1
|
| 71 |
+
// kv_head_idx_for_rank = tp_rank / (tp_size / num_kv_heads)
|
| 72 |
+
// This matches the GQA semantics: each group of (num_q_heads / num_kv_heads) Q heads
|
| 73 |
+
// shares one KV head. FIAS is given matched Hq (rank-local Q heads) and Hkv=1.
|
| 74 |
+
if (tp_size <= num_key_value_heads && num_key_value_heads % tp_size == 0) {
|
| 75 |
+
n_kv_heads_per_rank = num_key_value_heads / tp_size;
|
| 76 |
+
} else if (tp_size % num_key_value_heads == 0) {
|
| 77 |
+
n_kv_heads_per_rank = 1;
|
| 78 |
+
} else {
|
| 79 |
+
fprintf(stderr, "WARN: non-standard TP/KV head ratio: tp=%d kv=%ld — falling back to replicate-all\n",
|
| 80 |
+
tp_size, num_key_value_heads);
|
| 81 |
+
n_kv_heads_per_rank = num_key_value_heads;
|
| 82 |
+
}
|
| 83 |
+
kv_dim_per_rank = n_kv_heads_per_rank * head_dim;
|
| 84 |
+
|
| 85 |
+
// MoE intermediate dim split
|
| 86 |
+
if (moe_intermediate_size % tp_size != 0) {
|
| 87 |
+
fprintf(stderr, "WARN: moe_intermediate_size=%ld not divisible by tp_size=%d\n",
|
| 88 |
+
moe_intermediate_size, tp_size);
|
| 89 |
+
}
|
| 90 |
+
i_per_rank = moe_intermediate_size / tp_size;
|
| 91 |
+
}
|
| 92 |
+
|
| 93 |
+
// Human-readable dump of the parsed model hyperparameters plus the
// TP-derived per-rank shard sizes (intended for startup logging/debugging).
// Read-only; formats only fields already set by load_from_json()/compute_derived().
std::string ModelConfig::describe() const {
    std::ostringstream os;
    os << "Qwen3MoE config:\n"
       << "  vocab_size = " << vocab_size << "\n"
       << "  hidden_size = " << hidden_size << "\n"
       << "  num_hidden_layers = " << num_hidden_layers << "\n"
       << "  num_attention_heads = " << num_attention_heads << "\n"
       << "  num_key_value_heads = " << num_key_value_heads << "\n"
       << "  head_dim = " << head_dim << "\n"
       << "  num_experts = " << num_experts << "\n"
       << "  num_experts_per_tok = " << num_experts_per_tok << "\n"
       << "  moe_intermediate_size = " << moe_intermediate_size << "\n"
       << "  rope_theta = " << rope_theta << "\n"
       << "  rms_norm_eps = " << rms_norm_eps << "\n"
       << "  max_pos_embeddings = " << max_position_embeddings << "\n"
       << "TP rank " << tp_rank << " / " << tp_size << " derived:\n"
       << "  n_heads_per_rank = " << n_heads_per_rank << "\n"
       << "  q_dim_per_rank = " << q_dim_per_rank << "\n"
       << "  n_kv_heads_per_rank = " << n_kv_heads_per_rank << "\n"
       << "  kv_dim_per_rank = " << kv_dim_per_rank << "\n"
       << "  i_per_rank = " << i_per_rank << "\n";
    return os.str();
}
|
|
@@ -0,0 +1,428 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include "runner.h"
|
| 2 |
+
|
| 3 |
+
#include <chrono>
|
| 4 |
+
#include <cstdio>
|
| 5 |
+
#include <cstring>
|
| 6 |
+
|
| 7 |
+
// Expose HCCL context for the CLI broadcast helper.
|
| 8 |
+
HcclCtx* runner_hccl_ctx_shim(Runner& r) { return &r.hccl_ctx(); }
|
| 9 |
+
|
| 10 |
+
// One-shot initialization: load config, open safetensors, bring up the ACL
// runtime + HCCL (TP), upload all weights to device, allocate KV caches, and
// build the static prefill mask and RoPE table.
// Returns false (with a message on stderr) on any failure; no partial cleanup
// is attempted here — member destructors own the resources.
bool Runner::init(const std::string& model_dir, int tp_size, int tp_rank,
                  int num_layers_to_load, int64_t max_seq, int device_id) {
  if (!cfg_.load_from_json(model_dir + "/config.json")) return false;
  // Derives per-rank head/dim splits for this (tp_size, tp_rank).
  cfg_.compute_derived(tp_size, tp_rank);
  if (num_layers_to_load < 1 || num_layers_to_load > (int)cfg_.num_hidden_layers) {
    fprintf(stderr, "runner: invalid num_layers %d (max %ld)\n",
            num_layers_to_load, cfg_.num_hidden_layers);
    return false;
  }
  num_layers_ = num_layers_to_load;
  max_seq_ = max_seq;

  if (!st_.open(model_dir)) return false;
  rt_.init(device_id);

  // HCCL init (no-op if tp_size == 1)
  if (!hccl_init(hccl_ctx_, tp_size, tp_rank)) {
    fprintf(stderr, "runner: HCCL init failed\n");
    return false;
  }

  DeviceWeightsLoader dw(st_, cfg_);
  printf("runner: loading shared weights (embed, lm_head, final_norm)...\n");
  if (!dw.load_shared(shared_)) return false;

  attn_.resize(num_layers_);
  moe_.resize(num_layers_);
  k_cache_.resize(num_layers_);
  v_cache_.resize(num_layers_);

  // Per-rank KV width. The "* 2" below is bytes per element (BF16 is used
  // for activations/caches throughout this file — TODO confirm cache dtype).
  const int64_t KV_DIM = cfg_.n_kv_heads_per_rank * cfg_.head_dim;
  for (int L = 0; L < num_layers_; L++) {
    printf("runner: loading layer %d/%d...\n", L + 1, num_layers_);
    if (!dw.load_attention(L, attn_[L])) return false;
    if (!dw.load_moe(L, rt_.stream(), moe_[L])) return false;
    k_cache_[L].alloc(max_seq_ * KV_DIM * 2);
    v_cache_[L].alloc(max_seq_ * KV_DIM * 2);
  }
  rt_.sync();

  // Prefill mask (2048x2048 bool causal)
  // Upper triangle (j > i) is 1 = masked; used with FIAS sparse_mode=3 so the
  // kernel extends causality internally for longer sequences.
  const int64_t MASK = 2048;
  std::vector<uint8_t> mh(MASK * MASK, 0);
  for (int i = 0; i < MASK; i++)
    for (int j = i+1; j < MASK; j++) mh[i*MASK + j] = 1;
  prefill_mask_dev_.alloc(MASK * MASK);
  ACL_CHECK(aclrtMemcpy(prefill_mask_dev_.get(), MASK*MASK, mh.data(), MASK*MASK, ACL_MEMCPY_HOST_TO_DEVICE));

  // Pre-compute RoPE cos/sin table once (covers all positions up to max_seq_)
  rope_cache_build(rope_cache_, max_seq_, cfg_.head_dim, cfg_.rope_theta);

  past_len_ = 0;
  cur_S_capacity_ = 0;
  return true;
}
|
| 65 |
+
|
| 66 |
+
// Grow-only scratch guarantee: (re)allocate `buf` only when its current
// capacity is below `needed`; an already-large buffer is left untouched.
static void ensure_sc_(DeviceBuffer& buf, size_t needed) {
  const bool too_small = buf.size < needed;
  if (too_small) {
    buf.alloc(needed);
  }
}
|
| 69 |
+
|
| 70 |
+
// Size (grow-only) every scratch buffer the forward pass needs for a batch of
// S positions. Sizes are in BYTES: the trailing multiplier is bytes/element
// (2 = BF16, 4 = FP32/INT32, 8 = INT64 — inferred from the tensor dtypes used
// at the call sites in this file; TODO confirm against aclnn op signatures).
// Called before prefill/decode/decode_batch so buffers are reused across steps.
static void ensure_all_scratch_(Runner* self, int64_t S, const ModelConfig& cfg,
                                DeviceBuffer& q_sc, DeviceBuffer& k_sc, DeviceBuffer& v_sc,
                                DeviceBuffer& xn_sc, DeviceBuffer& rstd_sc, DeviceBuffer& rope_sc,
                                DeviceBuffer& attn_fias_sc, DeviceBuffer& attn_out_sc,
                                DeviceBuffer& moe_xn, DeviceBuffer& moe_rstd, DeviceBuffer& moe_logits,
                                DeviceBuffer& moe_topk_w, DeviceBuffer& moe_topk_idx, DeviceBuffer& moe_row_idx,
                                DeviceBuffer& moe_ex_x, DeviceBuffer& moe_ex_ri, DeviceBuffer& moe_tpe,
                                DeviceBuffer& moe_fwd,
                                DeviceBuffer& moe_gate, DeviceBuffer& moe_up, DeviceBuffer& moe_down,
                                DeviceBuffer& moe_packed, DeviceBuffer& moe_weighted, DeviceBuffer& moe_out,
                                DeviceBuffer& moe_norm_sum,
                                DeviceBuffer& x_buf_a, DeviceBuffer& x_buf_b) {
  (void)self;
  const int64_t D = cfg.hidden_size;
  const int64_t Hq = cfg.n_heads_per_rank, Hkv = cfg.n_kv_heads_per_rank;
  const int64_t Dh = cfg.head_dim;
  const int64_t Q_DIM = Hq * Dh;
  const int64_t KV_DIM = Hkv * Dh;
  const int64_t I = cfg.i_per_rank, E = cfg.num_experts, K = cfg.num_experts_per_tok;
  // Expanded token count: each of the S tokens is routed to K experts.
  const int64_t TOTAL = S * K;

  // --- attention scratch ---
  ensure_sc_(q_sc, S * Q_DIM * 2);
  ensure_sc_(k_sc, S * KV_DIM * 2);
  ensure_sc_(v_sc, S * KV_DIM * 2);
  ensure_sc_(xn_sc, S * D * 2);
  // rstd sized for the larger of Q/KV head counts (shared between norms).
  ensure_sc_(rstd_sc, S * std::max(Hq, Hkv) * 4);
  ensure_sc_(rope_sc, 1 * S * Hq * Dh * 2);
  ensure_sc_(attn_fias_sc, S * Q_DIM * 2);
  ensure_sc_(attn_out_sc, S * D * 2);

  // --- MoE scratch ---
  ensure_sc_(moe_xn, S * D * 2);
  ensure_sc_(moe_rstd, S * 4);
  ensure_sc_(moe_logits, S * E * 2);
  ensure_sc_(moe_topk_w, S * K * 2);
  ensure_sc_(moe_topk_idx, S * K * 4);
  ensure_sc_(moe_row_idx, S * K * 4);
  ensure_sc_(moe_ex_x, TOTAL * D * 2);
  ensure_sc_(moe_ex_ri, TOTAL * 4);
  ensure_sc_(moe_tpe, E * 8);
  ensure_sc_(moe_fwd, TOTAL * 8);
  ensure_sc_(moe_gate, TOTAL * I * 2);
  ensure_sc_(moe_up, TOTAL * I * 2);
  ensure_sc_(moe_down, TOTAL * D * 2);
  ensure_sc_(moe_packed, TOTAL * D * 2);
  ensure_sc_(moe_weighted, S * K * D * 2);
  ensure_sc_(moe_out, S * D * 2);
  ensure_sc_(moe_norm_sum, S * 2);

  // --- layer ping-pong activation buffers ---
  ensure_sc_(x_buf_a, S * D * 2);
  ensure_sc_(x_buf_b, S * D * 2);
}
|
| 121 |
+
|
| 122 |
+
void Runner::layer_forward_(int layer_idx, int64_t S, void* x_in, void* x_out, bool batch_decode_mode) {
|
| 123 |
+
const int64_t D = cfg_.hidden_size;
|
| 124 |
+
|
| 125 |
+
// Attention mask selection:
|
| 126 |
+
// prefill (S>1, past=0): 2048×2048 upper-tri + sparse_mode=3 (FIAS internal causal)
|
| 127 |
+
// decode (S==1): mask=nullptr + sparse_mode=0 (single query sees all cache)
|
| 128 |
+
// batch decode (S>1, past>0): S × (past+S) causal-with-past + sparse_mode=0
|
| 129 |
+
aclTensor* mask = nullptr;
|
| 130 |
+
int64_t sparse_mode = -1; // auto
|
| 131 |
+
AclTensorPtr t_mask_ptr;
|
| 132 |
+
if (batch_decode_mode) {
|
| 133 |
+
build_batch_decode_mask_(S);
|
| 134 |
+
int64_t kv_len = past_len_ + S;
|
| 135 |
+
t_mask_ptr = make_contig_tensor(batch_mask_dev_.get(), ACL_BOOL, {1, 1, S, kv_len});
|
| 136 |
+
mask = t_mask_ptr.get();
|
| 137 |
+
sparse_mode = 0;
|
| 138 |
+
} else if (S > 1) {
|
| 139 |
+
// Pure prefill from past=0
|
| 140 |
+
t_mask_ptr = make_contig_tensor(prefill_mask_dev_.get(), ACL_BOOL, {1, 1, 2048, 2048});
|
| 141 |
+
mask = t_mask_ptr.get();
|
| 142 |
+
sparse_mode = 3;
|
| 143 |
+
}
|
| 144 |
+
// else: S=1 decode, mask=nullptr, sparse_mode=0 (auto)
|
| 145 |
+
|
| 146 |
+
attention_forward(
|
| 147 |
+
rt_.stream(), cfg_, attn_[layer_idx],
|
| 148 |
+
x_in, S, past_len_,
|
| 149 |
+
k_cache_[layer_idx].get(), v_cache_[layer_idx].get(), max_seq_,
|
| 150 |
+
mask,
|
| 151 |
+
q_sc_.get(), k_sc_.get(), v_sc_.get(),
|
| 152 |
+
xn_sc_.get(), rstd_sc_.get(), rope_sc_.get(),
|
| 153 |
+
attn_fias_sc_.get(),
|
| 154 |
+
attn_out_sc_.get(),
|
| 155 |
+
(hccl_ctx_.tp_size > 1) ? &hccl_ctx_ : nullptr,
|
| 156 |
+
&rope_cache_,
|
| 157 |
+
sparse_mode);
|
| 158 |
+
|
| 159 |
+
// x1 = x_in + attn_out (residual)
|
| 160 |
+
auto t_x_in = make_contig_tensor(x_in, ACL_BF16, {S, D});
|
| 161 |
+
auto t_attn_out= make_contig_tensor(attn_out_sc_.get(), ACL_BF16, {S, D});
|
| 162 |
+
auto t_x1 = make_contig_tensor(x_buf_a_.get(), ACL_BF16, {S, D});
|
| 163 |
+
{
|
| 164 |
+
float a = 1.0f; aclScalar* al = aclCreateScalar(&a, ACL_FLOAT);
|
| 165 |
+
uint64_t ws = 0; aclOpExecutor* e = nullptr;
|
| 166 |
+
ACLNN_CHECK(aclnnAddGetWorkspaceSize(t_x_in.get(), t_attn_out.get(), al, t_x1.get(), &ws, &e));
|
| 167 |
+
DeviceBuffer wb; if (ws > 0) wb.alloc(ws);
|
| 168 |
+
ACLNN_CHECK(aclnnAdd(wb.get(), ws, e, rt_.stream()));
|
| 169 |
+
aclDestroyScalar(al);
|
| 170 |
+
}
|
| 171 |
+
|
| 172 |
+
// MoE
|
| 173 |
+
moe_forward(
|
| 174 |
+
rt_.stream(), cfg_, attn_[layer_idx], moe_[layer_idx],
|
| 175 |
+
x_buf_a_.get(), S,
|
| 176 |
+
moe_xn_.get(), moe_rstd_.get(),
|
| 177 |
+
moe_logits_.get(),
|
| 178 |
+
moe_topk_w_.get(), moe_topk_idx_.get(), moe_row_idx_.get(),
|
| 179 |
+
moe_ex_x_.get(), moe_ex_ri_.get(), moe_tpe_.get(),
|
| 180 |
+
moe_fwd_.get(),
|
| 181 |
+
moe_gate_.get(), moe_up_.get(), moe_down_.get(),
|
| 182 |
+
moe_packed_.get(), moe_weighted_.get(),
|
| 183 |
+
moe_out_.get(),
|
| 184 |
+
(hccl_ctx_.tp_size > 1) ? &hccl_ctx_ : nullptr,
|
| 185 |
+
moe_norm_sum_.get());
|
| 186 |
+
|
| 187 |
+
// x_out = x1 + moe_out (residual)
|
| 188 |
+
auto t_moe_out = make_contig_tensor(moe_out_.get(), ACL_BF16, {S, D});
|
| 189 |
+
auto t_out = make_contig_tensor(x_out, ACL_BF16, {S, D});
|
| 190 |
+
{
|
| 191 |
+
float a = 1.0f; aclScalar* al = aclCreateScalar(&a, ACL_FLOAT);
|
| 192 |
+
uint64_t ws = 0; aclOpExecutor* e = nullptr;
|
| 193 |
+
ACLNN_CHECK(aclnnAddGetWorkspaceSize(t_x1.get(), t_moe_out.get(), al, t_out.get(), &ws, &e));
|
| 194 |
+
DeviceBuffer wb; if (ws > 0) wb.alloc(ws);
|
| 195 |
+
ACLNN_CHECK(aclnnAdd(wb.get(), ws, e, rt_.stream()));
|
| 196 |
+
aclDestroyScalar(al);
|
| 197 |
+
}
|
| 198 |
+
}
|
| 199 |
+
|
| 200 |
+
// Final-norm + lm_head projection for one position.
void Runner::final_logits_(void* hidden_last, DeviceBuffer& logits_out) {
  // Single-position variant: hidden_last is [1, D], output [1, V].
  // Delegates to the batched path with S=1.
  final_logits_batch_(hidden_last, 1, logits_out);
}
|
| 204 |
+
|
| 205 |
+
// RMSNorm the last hidden states [S, D] with the final-norm weight, then
// project through lm_head to vocabulary logits [S, V] (BF16).
// logits_out is (re)allocated here; work is enqueued on rt_.stream() and the
// caller is responsible for syncing before reading the result.
void Runner::final_logits_batch_(void* hidden, int64_t S, DeviceBuffer& logits_out) {
  const int64_t D = cfg_.hidden_size;
  const int64_t V = cfg_.vocab_size;

  // Temporaries: normalized hidden (BF16) and per-row rstd (FP32).
  DeviceBuffer hn(S * D * 2), rstd(S * 4);
  auto t_h = make_contig_tensor(hidden, ACL_BF16, {S, D});
  auto t_hn = make_contig_tensor(hn.get(), ACL_BF16, {S, D});
  auto t_lnw = make_contig_tensor(shared_.final_norm.get(), ACL_BF16, {D});
  auto t_rstd = make_contig_tensor(rstd.get(), ACL_FLOAT, {S});
  rms_norm(rt_.stream(), t_h.get(), t_lnw.get(), cfg_.rms_norm_eps, t_hn.get(), t_rstd.get());

  logits_out.alloc(S * V * 2);
  auto t_logits = make_contig_tensor(logits_out.get(), ACL_BF16, {S, V});
  // lm_head weight is [V, D] (HF layout) — linear_hf handles the transpose.
  linear_hf(rt_.stream(), t_hn.get(), shared_.lm_head.get(), ACL_BF16, V, D, t_logits.get());
}
|
| 220 |
+
|
| 221 |
+
// Verify S draft tokens in one forward pass (speculative/PLD verification):
// embeds all S tokens, runs every layer with a causal-with-past mask, and
// returns logits for ALL S positions in all_logits_out ([S, V], BF16).
// Advances past_len_ by S on success. Returns false if S < 1 or the batch
// would overflow max_seq_.
bool Runner::decode_batch(const int32_t* tokens, int64_t S, DeviceBuffer& all_logits_out) {
  if (S < 1) return false;
  if (past_len_ + S > max_seq_) {
    fprintf(stderr, "runner: decode_batch exceeds max_seq (%ld + %ld > %ld)\n",
            past_len_, S, max_seq_);
    return false;
  }
  const int64_t D = cfg_.hidden_size;

  // Grow-only scratch sizing for this S.
  ensure_all_scratch_(this, S, cfg_,
      q_sc_, k_sc_, v_sc_, xn_sc_, rstd_sc_, rope_sc_, attn_fias_sc_, attn_out_sc_,
      moe_xn_, moe_rstd_, moe_logits_,
      moe_topk_w_, moe_topk_idx_, moe_row_idx_,
      moe_ex_x_, moe_ex_ri_, moe_tpe_,
      moe_fwd_,
      moe_gate_, moe_up_, moe_down_,
      moe_packed_, moe_weighted_, moe_out_,
      moe_norm_sum_,
      x_buf_a_, x_buf_b_);

  // Embed S tokens
  DeviceBuffer tok_dev(S * 4);
  ACL_CHECK(aclrtMemcpy(tok_dev.get(), S*4, tokens, S*4, ACL_MEMCPY_HOST_TO_DEVICE));
  auto t_tok = make_contig_tensor(tok_dev.get(), ACL_INT32, {S});
  auto t_embed_w = make_contig_tensor(shared_.embed_tokens.get(), ACL_BF16, {cfg_.vocab_size, D});

  DeviceBuffer x0(S * D * 2);
  auto t_x0 = make_contig_tensor(x0.get(), ACL_BF16, {S, D});
  index_select(rt_.stream(), t_embed_w.get(), 0, t_tok.get(), t_x0.get());

  // Layer chain: ping-pong buffers; cur_in holds the final output after the
  // last swap.
  DeviceBuffer xping(S * D * 2), xpong(S * D * 2);
  ACL_CHECK(aclrtMemcpyAsync(xping.get(), S*D*2, x0.get(), S*D*2, ACL_MEMCPY_DEVICE_TO_DEVICE, rt_.stream()));
  void* cur_in = xping.get();
  void* cur_out = xpong.get();
  // batch_decode_mode=true uses proper causal-with-past mask (S × past+S, sparse_mode=0).
  // With past_len_ == 0 this degenerates to the plain prefill mask path.
  for (int L = 0; L < num_layers_; L++) {
    layer_forward_(L, S, cur_in, cur_out, /*batch_decode_mode=*/past_len_ > 0);
    std::swap(cur_in, cur_out);
  }
  rt_.sync();

  // Get logits for ALL S positions (not just last)
  final_logits_batch_(cur_in, S, all_logits_out);
  rt_.sync();

  past_len_ += S;
  return true;
}
|
| 269 |
+
|
| 270 |
+
// Prefill S prompt tokens: embed, run all layers (causal prefill mask),
// and return logits for the LAST position only ([1, V], BF16) in logits_out.
// Advances past_len_ by S on success; the KV caches are populated for all S
// positions as a side effect of the layer pass.
bool Runner::prefill(const int32_t* tokens, int64_t S, DeviceBuffer& logits_out) {
  if (S < 1) return false;
  if (past_len_ + S > max_seq_) {
    fprintf(stderr, "runner: prefill exceeds max_seq (%ld + %ld > %ld)\n",
            past_len_, S, max_seq_);
    return false;
  }

  const int64_t D = cfg_.hidden_size;
  // Grow-only scratch sizing for this S.
  ensure_all_scratch_(this, S, cfg_,
      q_sc_, k_sc_, v_sc_, xn_sc_, rstd_sc_, rope_sc_, attn_fias_sc_, attn_out_sc_,
      moe_xn_, moe_rstd_, moe_logits_,
      moe_topk_w_, moe_topk_idx_, moe_row_idx_,
      moe_ex_x_, moe_ex_ri_, moe_tpe_,
      moe_fwd_,
      moe_gate_, moe_up_, moe_down_,
      moe_packed_, moe_weighted_, moe_out_,
      moe_norm_sum_,
      x_buf_a_, x_buf_b_);

  // Embed
  DeviceBuffer tok_dev(S * 4);
  ACL_CHECK(aclrtMemcpy(tok_dev.get(), S*4, tokens, S*4, ACL_MEMCPY_HOST_TO_DEVICE));
  auto t_tok = make_contig_tensor(tok_dev.get(), ACL_INT32, {S});
  auto t_embed_w = make_contig_tensor(shared_.embed_tokens.get(), ACL_BF16, {cfg_.vocab_size, D});

  DeviceBuffer x0(S * D * 2);
  auto t_x0 = make_contig_tensor(x0.get(), ACL_BF16, {S, D});
  index_select(rt_.stream(), t_embed_w.get(), 0, t_tok.get(), t_x0.get());

  // Layer chain: ping-pong between two buffers
  DeviceBuffer xping(S * D * 2), xpong(S * D * 2);
  ACL_CHECK(aclrtMemcpyAsync(xping.get(), S*D*2, x0.get(), S*D*2, ACL_MEMCPY_DEVICE_TO_DEVICE, rt_.stream()));

  void* cur_in = xping.get();
  void* cur_out = xpong.get();
  for (int L = 0; L < num_layers_; L++) {
    // batch_decode_mode defaults off → prefill mask (sparse_mode=3) for S>1.
    layer_forward_(L, S, cur_in, cur_out);
    std::swap(cur_in, cur_out);
  }
  rt_.sync();

  // Take last position's hidden → final_logits
  DeviceBuffer last(1 * D * 2);
  ACL_CHECK(aclrtMemcpy(last.get(), 1*D*2,
                        (char*)cur_in + (S - 1) * D * 2, 1*D*2,
                        ACL_MEMCPY_DEVICE_TO_DEVICE));
  final_logits_(last.get(), logits_out);
  rt_.sync();

  past_len_ += S;
  return true;
}
|
| 323 |
+
|
| 324 |
+
// Single-token autoregressive step: embed one token, run all layers with the
// KV cache (S=1, mask=nullptr), and produce logits [1, V] in logits_out.
// Advances past_len_ by 1 on success. When profile_enabled is set, per-phase
// wall-clock timings (embed / layers / final) are accumulated; the extra
// stream sync after the embed copy makes the embed timing honest at the cost
// of pipeline overlap.
bool Runner::decode(int32_t token, DeviceBuffer& logits_out) {
  const int64_t D = cfg_.hidden_size;
  if (past_len_ + 1 > max_seq_) {
    fprintf(stderr, "runner: decode exceeds max_seq\n");
    return false;
  }

  const int64_t S = 1;
  // Grow-only scratch sizing; cheap no-op after the first call.
  ensure_all_scratch_(this, S, cfg_,
      q_sc_, k_sc_, v_sc_, xn_sc_, rstd_sc_, rope_sc_, attn_fias_sc_, attn_out_sc_,
      moe_xn_, moe_rstd_, moe_logits_,
      moe_topk_w_, moe_topk_idx_, moe_row_idx_,
      moe_ex_x_, moe_ex_ri_, moe_tpe_,
      moe_fwd_,
      moe_gate_, moe_up_, moe_down_,
      moe_packed_, moe_weighted_, moe_out_,
      moe_norm_sum_,
      x_buf_a_, x_buf_b_);

  DeviceBuffer tok_dev(1 * 4);
  ACL_CHECK(aclrtMemcpy(tok_dev.get(), 4, &token, 4, ACL_MEMCPY_HOST_TO_DEVICE));
  auto t_tok = make_contig_tensor(tok_dev.get(), ACL_INT32, {1});
  auto t_embed_w = make_contig_tensor(shared_.embed_tokens.get(), ACL_BF16, {cfg_.vocab_size, D});

  auto t0 = std::chrono::steady_clock::now();

  // Embedding lookup for the single token.
  DeviceBuffer x0(1 * D * 2);
  auto t_x0 = make_contig_tensor(x0.get(), ACL_BF16, {1, D});
  index_select(rt_.stream(), t_embed_w.get(), 0, t_tok.get(), t_x0.get());

  DeviceBuffer xping(1 * D * 2), xpong(1 * D * 2);
  ACL_CHECK(aclrtMemcpyAsync(xping.get(), 1*D*2, x0.get(), 1*D*2, ACL_MEMCPY_DEVICE_TO_DEVICE, rt_.stream()));
  // Sync only when profiling so t1-t0 measures the real embed cost.
  if (profile_enabled) { ACL_CHECK(aclrtSynchronizeStream(rt_.stream())); }
  auto t1 = std::chrono::steady_clock::now();

  void* cur_in = xping.get();
  void* cur_out = xpong.get();
  for (int L = 0; L < num_layers_; L++) {
    layer_forward_(L, 1, cur_in, cur_out);
    std::swap(cur_in, cur_out);
  }
  rt_.sync();
  auto t2 = std::chrono::steady_clock::now();

  final_logits_(cur_in, logits_out);
  rt_.sync();
  auto t3 = std::chrono::steady_clock::now();

  if (profile_enabled) {
    using ms = std::chrono::duration<double, std::milli>;
    t_embed_ms += ms(t1 - t0).count();
    t_layers_ms += ms(t2 - t1).count();
    t_final_ms += ms(t3 - t2).count();
    profile_calls++;
  }

  past_len_ += 1;
  return true;
}
|
| 383 |
+
|
| 384 |
+
void Runner::build_batch_decode_mask_(int64_t S) {
|
| 385 |
+
int64_t kv_len = past_len_ + S;
|
| 386 |
+
size_t bytes = (size_t)S * kv_len; // bool = 1 byte
|
| 387 |
+
if (batch_mask_dev_.size < bytes) batch_mask_dev_.alloc(bytes);
|
| 388 |
+
std::vector<uint8_t> h_mask(bytes, 0);
|
| 389 |
+
for (int64_t i = 0; i < S; i++) {
|
| 390 |
+
// Row i: positions j ≤ past_len_+i are visible (0), j > past_len_+i are masked (1).
|
| 391 |
+
for (int64_t j = past_len_ + i + 1; j < kv_len; j++) {
|
| 392 |
+
h_mask[i * kv_len + j] = 1;
|
| 393 |
+
}
|
| 394 |
+
}
|
| 395 |
+
ACL_CHECK(aclrtMemcpy(batch_mask_dev_.get(), bytes, h_mask.data(), bytes,
|
| 396 |
+
ACL_MEMCPY_HOST_TO_DEVICE));
|
| 397 |
+
}
|
| 398 |
+
|
| 399 |
+
void Runner::warmup(int iterations) {
|
| 400 |
+
if (num_layers_ == 0) return;
|
| 401 |
+
int64_t saved_past = past_len_;
|
| 402 |
+
past_len_ = 0;
|
| 403 |
+
int32_t dummy_tok = 0; // token id 0, valid for Qwen3 (bos)
|
| 404 |
+
DeviceBuffer dummy_logits;
|
| 405 |
+
for (int i = 0; i < iterations; i++) {
|
| 406 |
+
past_len_ = 0;
|
| 407 |
+
if (!decode(dummy_tok, dummy_logits)) break;
|
| 408 |
+
}
|
| 409 |
+
past_len_ = saved_past;
|
| 410 |
+
fprintf(stderr, "[runner] warmup: %d iterations done\n", iterations);
|
| 411 |
+
}
|
| 412 |
+
|
| 413 |
+
// Dump the accumulated decode-phase timing breakdown to stderr.
// No-op unless profiling was enabled and at least one decode() ran (which
// also guarantees total > 0 since the profiled path syncs each phase).
void Runner::print_profile_summary() const {
  if (!profile_enabled || profile_calls == 0) return;
  double total = t_embed_ms + t_layers_ms + t_final_ms;
  fprintf(stderr, "\n=== Runner profile (%ld decode calls) ===\n", profile_calls);
  fprintf(stderr, " phase total_ms avg_ms/call pct\n");
  fprintf(stderr, " embed %8.1f %10.3f %5.1f%%\n",
          t_embed_ms, t_embed_ms / profile_calls, 100.0 * t_embed_ms / total);
  fprintf(stderr, " layers (x%d) %8.1f %10.3f %5.1f%% → %.3f ms/layer/call\n",
          num_layers_, t_layers_ms, t_layers_ms / profile_calls,
          100.0 * t_layers_ms / total,
          t_layers_ms / profile_calls / num_layers_);
  fprintf(stderr, " final+lm_hd %8.1f %10.3f %5.1f%%\n",
          t_final_ms, t_final_ms / profile_calls, 100.0 * t_final_ms / total);
  fprintf(stderr, " total %8.1f %10.3f 100.0%%\n",
          total, total / profile_calls);
}
|
|
@@ -0,0 +1,172 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include "safetensors_loader.h"
|
| 2 |
+
|
| 3 |
+
#include <fcntl.h>
|
| 4 |
+
#include <sys/mman.h>
|
| 5 |
+
#include <sys/stat.h>
|
| 6 |
+
#include <unistd.h>
|
| 7 |
+
|
| 8 |
+
#include <cstdio>
|
| 9 |
+
#include <cstring>
|
| 10 |
+
#include <fstream>
|
| 11 |
+
#include <sstream>
|
| 12 |
+
|
| 13 |
+
#include "json.hpp"
|
| 14 |
+
|
| 15 |
+
using json = nlohmann::json;
|
| 16 |
+
|
| 17 |
+
// Trivial construction; all state lives in member containers.
SafetensorsLoader::SafetensorsLoader() = default;

// Tear down every shard that was lazily mapped: unmap the region first,
// then close its file descriptor (shards never mapped have null ptr / fd -1).
SafetensorsLoader::~SafetensorsLoader() {
  for (auto& shard : shards_) {
    if (shard.mmap_ptr != nullptr) {
      munmap(shard.mmap_ptr, shard.mmap_size);
    }
    if (shard.fd >= 0) {
      close(shard.fd);
    }
  }
}
|
| 25 |
+
|
| 26 |
+
// Discover the model's safetensors shard files and index every tensor.
// Prefers model.safetensors.index.json (sharded HF layout); falls back to a
// single model.safetensors. Only headers are parsed here — shard data is
// mmap'd lazily on first data_ptr() access. Returns false with a stderr
// message on any parse/IO failure.
bool SafetensorsLoader::open(const std::string& dir) {
  model_dir_ = dir;

  // 1. Parse index.json to discover shard files
  std::string idx_path = dir + "/model.safetensors.index.json";
  std::ifstream idx_file(idx_path);
  if (!idx_file) {
    // Fallback: single-file model
    std::string single = dir + "/model.safetensors";
    std::ifstream f(single);
    if (!f) {
      fprintf(stderr, "SafetensorsLoader: neither index.json nor model.safetensors found in %s\n", dir.c_str());
      return false;
    }
    shards_.push_back({single});
    return parse_shard_header_(0);
  }

  json idx;
  try { idx_file >> idx; } catch (std::exception& e) {
    fprintf(stderr, "SafetensorsLoader: bad index.json: %s\n", e.what());
    return false;
  }
  if (!idx.contains("weight_map")) {
    fprintf(stderr, "SafetensorsLoader: index.json missing weight_map\n");
    return false;
  }

  // Collect unique shard filenames (preserving discovery order).
  // weight_map is tensor-name → shard-filename; many tensors share a shard.
  std::map<std::string, int> shard_name_to_id;
  for (auto& [name, file] : idx["weight_map"].items()) {
    std::string shard_name = file.get<std::string>();
    if (shard_name_to_id.count(shard_name) == 0) {
      int id = (int)shards_.size();
      shard_name_to_id[shard_name] = id;
      shards_.push_back({dir + "/" + shard_name});
    }
  }

  // 2. Parse header of each shard to discover tensor offsets
  for (int i = 0; i < (int)shards_.size(); i++) {
    if (!parse_shard_header_(i)) {
      fprintf(stderr, "SafetensorsLoader: failed to parse shard %s\n", shards_[i].path.c_str());
      return false;
    }
  }

  return true;
}
|
| 75 |
+
|
| 76 |
+
bool SafetensorsLoader::parse_shard_header_(int shard_id) {
|
| 77 |
+
ShardFile& sh = shards_[shard_id];
|
| 78 |
+
std::ifstream f(sh.path, std::ios::binary);
|
| 79 |
+
if (!f) return false;
|
| 80 |
+
|
| 81 |
+
// Read 8-byte little-endian header length
|
| 82 |
+
uint64_t header_len = 0;
|
| 83 |
+
f.read((char*)&header_len, 8);
|
| 84 |
+
if (!f) return false;
|
| 85 |
+
|
| 86 |
+
std::string header(header_len, '\0');
|
| 87 |
+
f.read(header.data(), header_len);
|
| 88 |
+
if (!f) return false;
|
| 89 |
+
|
| 90 |
+
sh.data_base = 8 + header_len;
|
| 91 |
+
|
| 92 |
+
json j;
|
| 93 |
+
try { j = json::parse(header); } catch (std::exception& e) {
|
| 94 |
+
fprintf(stderr, "SafetensorsLoader: bad shard header JSON in %s: %s\n", sh.path.c_str(), e.what());
|
| 95 |
+
return false;
|
| 96 |
+
}
|
| 97 |
+
|
| 98 |
+
for (auto it = j.begin(); it != j.end(); ++it) {
|
| 99 |
+
const std::string& name = it.key();
|
| 100 |
+
if (name == "__metadata__") continue;
|
| 101 |
+
const auto& entry = it.value();
|
| 102 |
+
|
| 103 |
+
TensorMeta m;
|
| 104 |
+
m.name = name;
|
| 105 |
+
m.dtype = entry["dtype"].get<std::string>();
|
| 106 |
+
for (auto& d : entry["shape"]) m.shape.push_back(d.get<int64_t>());
|
| 107 |
+
const auto& offs = entry["data_offsets"];
|
| 108 |
+
size_t begin = offs[0].get<size_t>();
|
| 109 |
+
size_t end = offs[1].get<size_t>();
|
| 110 |
+
m.offset = sh.data_base + begin;
|
| 111 |
+
m.nbytes = end - begin;
|
| 112 |
+
m.shard_id = shard_id;
|
| 113 |
+
|
| 114 |
+
tensors_[name] = std::move(m);
|
| 115 |
+
}
|
| 116 |
+
|
| 117 |
+
return true;
|
| 118 |
+
}
|
| 119 |
+
|
| 120 |
+
bool SafetensorsLoader::mmap_shard_(int shard_id) {
|
| 121 |
+
ShardFile& sh = shards_[shard_id];
|
| 122 |
+
if (sh.mmap_ptr) return true;
|
| 123 |
+
|
| 124 |
+
sh.fd = ::open(sh.path.c_str(), O_RDONLY);
|
| 125 |
+
if (sh.fd < 0) {
|
| 126 |
+
perror("open");
|
| 127 |
+
return false;
|
| 128 |
+
}
|
| 129 |
+
struct stat st;
|
| 130 |
+
if (fstat(sh.fd, &st) != 0) return false;
|
| 131 |
+
sh.mmap_size = st.st_size;
|
| 132 |
+
|
| 133 |
+
sh.mmap_ptr = mmap(nullptr, sh.mmap_size, PROT_READ, MAP_PRIVATE, sh.fd, 0);
|
| 134 |
+
if (sh.mmap_ptr == MAP_FAILED) {
|
| 135 |
+
perror("mmap");
|
| 136 |
+
sh.mmap_ptr = nullptr;
|
| 137 |
+
return false;
|
| 138 |
+
}
|
| 139 |
+
return true;
|
| 140 |
+
}
|
| 141 |
+
|
| 142 |
+
// Look up tensor metadata by name; nullptr when absent. The returned pointer
// stays valid for the loader's lifetime (tensors_ is not mutated after open).
const TensorMeta* SafetensorsLoader::get(const std::string& name) const {
  const auto it = tensors_.find(name);
  return (it == tensors_.end()) ? nullptr : &it->second;
}
|
| 147 |
+
|
| 148 |
+
const void* SafetensorsLoader::data_ptr(const TensorMeta& m) {
|
| 149 |
+
if (m.shard_id < 0 || (size_t)m.shard_id >= shards_.size()) return nullptr;
|
| 150 |
+
if (!mmap_shard_(m.shard_id)) return nullptr;
|
| 151 |
+
ShardFile& sh = shards_[m.shard_id];
|
| 152 |
+
return (const char*)sh.mmap_ptr + m.offset;
|
| 153 |
+
}
|
| 154 |
+
|
| 155 |
+
const void* SafetensorsLoader::data_ptr(const std::string& name) {
|
| 156 |
+
const auto* m = get(name);
|
| 157 |
+
if (!m) return nullptr;
|
| 158 |
+
return data_ptr(*m);
|
| 159 |
+
}
|
| 160 |
+
|
| 161 |
+
// Snapshot of every registered tensor name (map iteration order, i.e. sorted).
std::vector<std::string> SafetensorsLoader::list_tensor_names() const {
  std::vector<std::string> names;
  names.reserve(tensors_.size());
  for (const auto& entry : tensors_) {
    names.push_back(entry.first);
  }
  return names;
}
|
| 167 |
+
|
| 168 |
+
size_t SafetensorsLoader::total_bytes() const {
|
| 169 |
+
size_t sum = 0;
|
| 170 |
+
for (auto& [k, v] : tensors_) sum += v.nbytes;
|
| 171 |
+
return sum;
|
| 172 |
+
}
|
|
@@ -0,0 +1,176 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include "tokenizer.h"
|
| 2 |
+
|
| 3 |
+
#include <array>
|
| 4 |
+
#include <cstdint>
|
| 5 |
+
#include <cstdio>
|
| 6 |
+
#include <cstdlib>
|
| 7 |
+
#include <cstring>
|
| 8 |
+
#include <fstream>
|
| 9 |
+
#include <memory>
|
| 10 |
+
#include <sstream>
|
| 11 |
+
#include <unistd.h>
|
| 12 |
+
|
| 13 |
+
// Load the id -> raw-bytes vocab table from a simple binary format:
//   uint32 num_entries, then per entry: uint32 len followed by `len` raw bytes.
// Returns false on any I/O error (the table may be partially filled in that case).
bool Tokenizer::load(const std::string& vocab_bin_path) {
    std::ifstream f(vocab_bin_path, std::ios::binary);
    if (!f) {
        fprintf(stderr, "Tokenizer: cannot open %s\n", vocab_bin_path.c_str());
        return false;
    }
    uint32_t num;
    f.read(reinterpret_cast<char*>(&num), 4);
    if (!f) return false;
    id_to_bytes_.resize(num);
    for (uint32_t i = 0; i < num; i++) {
        uint32_t len;
        f.read(reinterpret_cast<char*>(&len), 4);
        if (!f) return false;
        id_to_bytes_[i].resize(len);
        if (len > 0) {
            f.read(id_to_bytes_[i].data(), len);
            // Bug fix: a truncated payload read was previously not checked, so a
            // cut-off vocab file was silently accepted as a successful load.
            if (!f) return false;
        }
    }
    return true;
}
|
| 32 |
+
|
| 33 |
+
std::string Tokenizer::decode(int id) const {
|
| 34 |
+
if (id < 0 || (size_t)id >= id_to_bytes_.size()) return "";
|
| 35 |
+
return id_to_bytes_[id];
|
| 36 |
+
}
|
| 37 |
+
|
| 38 |
+
std::string Tokenizer::decode(const std::vector<int>& ids) const {
|
| 39 |
+
std::string out;
|
| 40 |
+
for (int id : ids) out += decode(id);
|
| 41 |
+
return out;
|
| 42 |
+
}
|
| 43 |
+
|
| 44 |
+
std::vector<int> Tokenizer::encode_via_python(const std::string& model_dir,
|
| 45 |
+
const std::string& prompt,
|
| 46 |
+
bool apply_chat_template) const {
|
| 47 |
+
// Call python subprocess to tokenize. Embed prompt via stdin to avoid shell-escape bugs.
|
| 48 |
+
std::string cmd;
|
| 49 |
+
// Set QWEN3_PYENV_INIT to override the Python env activation sequence (e.g., "source /opt/my_env/activate && ")
|
| 50 |
+
// Default assumes conda at ~/miniconda3 with env 'qwen3' and Ascend toolkit installed.
|
| 51 |
+
if (const char* init = std::getenv("QWEN3_PYENV_INIT")) {
|
| 52 |
+
cmd += init;
|
| 53 |
+
} else {
|
| 54 |
+
cmd += "source ${HOME}/miniconda3/etc/profile.d/conda.sh 2>/dev/null && ";
|
| 55 |
+
cmd += "conda activate qwen3 2>/dev/null || true; ";
|
| 56 |
+
cmd += "source /usr/local/Ascend/ascend-toolkit/set_env.sh 2>/dev/null || true; ";
|
| 57 |
+
}
|
| 58 |
+
cmd += "python3 -c \"";
|
| 59 |
+
cmd += "import sys, json;";
|
| 60 |
+
cmd += "from transformers import AutoTokenizer;";
|
| 61 |
+
cmd += "t = AutoTokenizer.from_pretrained('" + model_dir + "');";
|
| 62 |
+
cmd += "p = sys.stdin.read();";
|
| 63 |
+
if (apply_chat_template) {
|
| 64 |
+
cmd += "msg = [{'role': 'user', 'content': p}];";
|
| 65 |
+
cmd += "ids = t.apply_chat_template(msg, add_generation_prompt=True);";
|
| 66 |
+
} else {
|
| 67 |
+
cmd += "ids = t.encode(p);";
|
| 68 |
+
}
|
| 69 |
+
cmd += "print(' '.join(str(i) for i in ids));";
|
| 70 |
+
cmd += "\"";
|
| 71 |
+
|
| 72 |
+
// popen with stdin: use the two-pipe dance via temp file for safety
|
| 73 |
+
char tmpl[] = "/tmp/lca_prompt_XXXXXX";
|
| 74 |
+
int fd = mkstemp(tmpl);
|
| 75 |
+
if (fd < 0) { perror("mkstemp"); return {}; }
|
| 76 |
+
write(fd, prompt.data(), prompt.size());
|
| 77 |
+
close(fd);
|
| 78 |
+
|
| 79 |
+
std::string full = cmd + " < " + tmpl + " 2>/dev/null";
|
| 80 |
+
FILE* pipe = popen(full.c_str(), "r");
|
| 81 |
+
if (!pipe) { perror("popen"); unlink(tmpl); return {}; }
|
| 82 |
+
|
| 83 |
+
std::string out;
|
| 84 |
+
char buf[4096];
|
| 85 |
+
while (size_t n = fread(buf, 1, sizeof(buf), pipe)) out.append(buf, n);
|
| 86 |
+
pclose(pipe);
|
| 87 |
+
unlink(tmpl);
|
| 88 |
+
|
| 89 |
+
std::vector<int> ids;
|
| 90 |
+
std::istringstream iss(out);
|
| 91 |
+
int x;
|
| 92 |
+
while (iss >> x) ids.push_back(x);
|
| 93 |
+
return ids;
|
| 94 |
+
}
|
| 95 |
+
|
| 96 |
+
// JSON-escape a string for embedding inside a JSON string literal: the quote,
// backslash, and common control characters get their short escapes; any other
// byte below 0x20 becomes a \uXXXX escape. (This is JSON escaping, not shell
// quoting.)
static std::string json_escape(const std::string& s) {
    std::string escaped;
    escaped.reserve(s.size() + 8);
    for (char c : s) {
        const unsigned char uc = (unsigned char)c;
        if (c == '"') {
            escaped += "\\\"";
        } else if (c == '\\') {
            escaped += "\\\\";
        } else if (c == '\n') {
            escaped += "\\n";
        } else if (c == '\r') {
            escaped += "\\r";
        } else if (c == '\t') {
            escaped += "\\t";
        } else if (uc < 0x20) {
            char hex[8];
            snprintf(hex, sizeof(hex), "\\u%04x", uc);
            escaped += hex;
        } else {
            escaped += c;
        }
    }
    return escaped;
}
|
| 119 |
+
|
| 120 |
+
std::vector<int> Tokenizer::encode_conversation_via_python(
|
| 121 |
+
const std::string& model_dir,
|
| 122 |
+
const std::vector<std::pair<std::string, std::string>>& conversation,
|
| 123 |
+
bool add_generation_prompt) const
|
| 124 |
+
{
|
| 125 |
+
// Build JSON array of messages. Pass via stdin to avoid shell-escape issues.
|
| 126 |
+
std::string json_msgs = "[";
|
| 127 |
+
for (size_t i = 0; i < conversation.size(); i++) {
|
| 128 |
+
if (i > 0) json_msgs += ",";
|
| 129 |
+
json_msgs += "{\"role\":\"" + json_escape(conversation[i].first) + "\",";
|
| 130 |
+
json_msgs += "\"content\":\"" + json_escape(conversation[i].second) + "\"}";
|
| 131 |
+
}
|
| 132 |
+
json_msgs += "]";
|
| 133 |
+
|
| 134 |
+
std::string cmd;
|
| 135 |
+
// Set QWEN3_PYENV_INIT to override the Python env activation sequence (e.g., "source /opt/my_env/activate && ")
|
| 136 |
+
// Default assumes conda at ~/miniconda3 with env 'qwen3' and Ascend toolkit installed.
|
| 137 |
+
if (const char* init = std::getenv("QWEN3_PYENV_INIT")) {
|
| 138 |
+
cmd += init;
|
| 139 |
+
} else {
|
| 140 |
+
cmd += "source ${HOME}/miniconda3/etc/profile.d/conda.sh 2>/dev/null && ";
|
| 141 |
+
cmd += "conda activate qwen3 2>/dev/null || true; ";
|
| 142 |
+
cmd += "source /usr/local/Ascend/ascend-toolkit/set_env.sh 2>/dev/null || true; ";
|
| 143 |
+
}
|
| 144 |
+
cmd += "python3 -c \"";
|
| 145 |
+
cmd += "import sys, json;";
|
| 146 |
+
cmd += "from transformers import AutoTokenizer;";
|
| 147 |
+
cmd += "t = AutoTokenizer.from_pretrained('" + model_dir + "');";
|
| 148 |
+
cmd += "msgs = json.loads(sys.stdin.read());";
|
| 149 |
+
cmd += "ids = t.apply_chat_template(msgs, add_generation_prompt=";
|
| 150 |
+
cmd += add_generation_prompt ? "True" : "False";
|
| 151 |
+
cmd += ");";
|
| 152 |
+
cmd += "print(' '.join(str(i) for i in ids));";
|
| 153 |
+
cmd += "\"";
|
| 154 |
+
|
| 155 |
+
char tmpl[] = "/tmp/lca_conv_XXXXXX";
|
| 156 |
+
int fd = mkstemp(tmpl);
|
| 157 |
+
if (fd < 0) { perror("mkstemp"); return {}; }
|
| 158 |
+
write(fd, json_msgs.data(), json_msgs.size());
|
| 159 |
+
close(fd);
|
| 160 |
+
|
| 161 |
+
std::string full = cmd + " < " + tmpl + " 2>/dev/null";
|
| 162 |
+
FILE* pipe = popen(full.c_str(), "r");
|
| 163 |
+
if (!pipe) { perror("popen"); unlink(tmpl); return {}; }
|
| 164 |
+
|
| 165 |
+
std::string out;
|
| 166 |
+
char buf[4096];
|
| 167 |
+
while (size_t n = fread(buf, 1, sizeof(buf), pipe)) out.append(buf, n);
|
| 168 |
+
pclose(pipe);
|
| 169 |
+
unlink(tmpl);
|
| 170 |
+
|
| 171 |
+
std::vector<int> ids;
|
| 172 |
+
std::istringstream iss(out);
|
| 173 |
+
int x;
|
| 174 |
+
while (iss >> x) ids.push_back(x);
|
| 175 |
+
return ids;
|
| 176 |
+
}
|
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// hello_acl.cpp — smoke test: aclInit + device + stream + simple tensor + aclnnAdd
|
| 2 |
+
#include "acl_common.h"
|
| 3 |
+
#include <aclnnop/aclnn_add.h>
|
| 4 |
+
#include <cstdio>
|
| 5 |
+
#include <vector>
|
| 6 |
+
|
| 7 |
+
int main() {
|
| 8 |
+
ACL_CHECK(aclInit(nullptr));
|
| 9 |
+
ACL_CHECK(aclrtSetDevice(0));
|
| 10 |
+
aclrtContext ctx;
|
| 11 |
+
ACL_CHECK(aclrtCreateContext(&ctx, 0));
|
| 12 |
+
aclrtStream stream;
|
| 13 |
+
ACL_CHECK(aclrtCreateStream(&stream));
|
| 14 |
+
|
| 15 |
+
// Tiny test: a = [1, 2, 3, 4] f32, b = [10, 20, 30, 40] f32, out = a + b
|
| 16 |
+
const int64_t N = 4;
|
| 17 |
+
std::vector<float> a_host = {1.0f, 2.0f, 3.0f, 4.0f};
|
| 18 |
+
std::vector<float> b_host = {10.0f, 20.0f, 30.0f, 40.0f};
|
| 19 |
+
std::vector<float> out_host(N, 0.0f);
|
| 20 |
+
|
| 21 |
+
DeviceBuffer a_dev(N * sizeof(float));
|
| 22 |
+
DeviceBuffer b_dev(N * sizeof(float));
|
| 23 |
+
DeviceBuffer out_dev(N * sizeof(float));
|
| 24 |
+
|
| 25 |
+
ACL_CHECK(aclrtMemcpy(a_dev.get(), N * 4, a_host.data(), N * 4, ACL_MEMCPY_HOST_TO_DEVICE));
|
| 26 |
+
ACL_CHECK(aclrtMemcpy(b_dev.get(), N * 4, b_host.data(), N * 4, ACL_MEMCPY_HOST_TO_DEVICE));
|
| 27 |
+
|
| 28 |
+
auto a_t = make_contig_tensor(a_dev.get(), ACL_FLOAT, {N});
|
| 29 |
+
auto b_t = make_contig_tensor(b_dev.get(), ACL_FLOAT, {N});
|
| 30 |
+
auto out_t = make_contig_tensor(out_dev.get(), ACL_FLOAT, {N});
|
| 31 |
+
|
| 32 |
+
// aclnnAdd: out = a + alpha * b
|
| 33 |
+
float alpha_val = 1.0f;
|
| 34 |
+
aclScalar* alpha = aclCreateScalar(&alpha_val, ACL_FLOAT);
|
| 35 |
+
|
| 36 |
+
uint64_t ws_size = 0;
|
| 37 |
+
aclOpExecutor* executor = nullptr;
|
| 38 |
+
ACLNN_CHECK(aclnnAddGetWorkspaceSize(a_t.get(), b_t.get(), alpha, out_t.get(), &ws_size, &executor));
|
| 39 |
+
|
| 40 |
+
DeviceBuffer ws;
|
| 41 |
+
if (ws_size > 0) ws.alloc(ws_size);
|
| 42 |
+
ACLNN_CHECK(aclnnAdd(ws.get(), ws_size, executor, stream));
|
| 43 |
+
|
| 44 |
+
ACL_CHECK(aclrtSynchronizeStream(stream));
|
| 45 |
+
|
| 46 |
+
ACL_CHECK(aclrtMemcpy(out_host.data(), N * 4, out_dev.get(), N * 4, ACL_MEMCPY_DEVICE_TO_HOST));
|
| 47 |
+
|
| 48 |
+
printf("hello_acl: ");
|
| 49 |
+
for (int i = 0; i < N; i++) printf("%.1f ", out_host[i]);
|
| 50 |
+
printf("\n");
|
| 51 |
+
|
| 52 |
+
bool ok = (out_host[0] == 11.0f && out_host[1] == 22.0f &&
|
| 53 |
+
out_host[2] == 33.0f && out_host[3] == 44.0f);
|
| 54 |
+
printf(ok ? "PASS\n" : "FAIL\n");
|
| 55 |
+
|
| 56 |
+
aclDestroyScalar(alpha);
|
| 57 |
+
ACL_CHECK(aclrtDestroyStream(stream));
|
| 58 |
+
ACL_CHECK(aclrtDestroyContext(ctx));
|
| 59 |
+
ACL_CHECK(aclrtResetDevice(0));
|
| 60 |
+
aclFinalize();
|
| 61 |
+
return ok ? 0 : 1;
|
| 62 |
+
}
|
|
@@ -0,0 +1,319 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// test_attention_decode.cpp — validates single-layer attention with KV cache.
|
| 2 |
+
//
|
| 3 |
+
// Strategy: compare two paths yielding the same pos-5 attention output:
|
| 4 |
+
// Path A (reference): prefill 6 tokens in one shot → attn_out[5]
|
| 5 |
+
// Path B (decode): prefill 5 tokens → K/V cache; decode 6th token via cache → attn_out_decode[0]
|
| 6 |
+
//
|
| 7 |
+
// The two outputs should match within BF16 precision.
|
| 8 |
+
#include "acl_common.h"
|
| 9 |
+
#include "acl_runtime.h"
|
| 10 |
+
#include "aclnn_ops.h"
|
| 11 |
+
#include "device_weights.h"
|
| 12 |
+
#include "model_config.h"
|
| 13 |
+
#include "rope.h"
|
| 14 |
+
#include "safetensors_loader.h"
|
| 15 |
+
|
| 16 |
+
#include <cmath>
|
| 17 |
+
#include <cstdio>
|
| 18 |
+
#include <cstring>
|
| 19 |
+
#include <fstream>
|
| 20 |
+
#include <vector>
|
| 21 |
+
|
| 22 |
+
// Widen a bfloat16 bit pattern to float32: bf16 is the top 16 bits of an
// IEEE-754 binary32, so shift left and reinterpret.
static float bf16_to_float(uint16_t x) {
    const uint32_t bits = static_cast<uint32_t>(x) << 16;
    float out;
    std::memcpy(&out, &bits, sizeof(out));
    return out;
}
|
| 25 |
+
// Narrow float32 to bfloat16 with round-to-nearest-even on the truncated
// 16 low mantissa bits (the `(bits >> 16) & 1` term breaks ties to even).
static uint16_t float_to_bf16(float x) {
    uint32_t bits;
    std::memcpy(&bits, &x, sizeof(bits));
    const uint32_t rounding = 0x7FFFu + ((bits >> 16) & 1u);
    return static_cast<uint16_t>((bits + rounding) >> 16);
}
|
| 29 |
+
// Read an entire binary file into a byte vector. Returns an empty vector if
// the file cannot be opened or fully read.
// Bug fix: the original never checked the open — on a missing file tellg()
// returns -1, which became a huge size_t and crashed the vector allocation;
// truncated reads were also silently ignored.
static std::vector<uint8_t> read_file(const std::string& p) {
    std::ifstream f(p, std::ios::binary | std::ios::ate);
    if (!f) {
        fprintf(stderr, "read_file: cannot open %s\n", p.c_str());
        return {};
    }
    const std::streamoff size = f.tellg();
    if (size < 0) return {};
    f.seekg(0);
    std::vector<uint8_t> v(static_cast<size_t>(size));
    if (size > 0 && !f.read(reinterpret_cast<char*>(v.data()), size)) return {};
    return v;
}
|
| 33 |
+
|
| 34 |
+
// Fill cos/sin tables for a range of positions [p0, p0+L). HF layout: half-half.
|
| 35 |
+
static void fill_cos_sin(std::vector<uint16_t>& cos_h, std::vector<uint16_t>& sin_h,
|
| 36 |
+
int64_t p0, int64_t L, int64_t Dh, float theta) {
|
| 37 |
+
cos_h.resize(L * Dh); sin_h.resize(L * Dh);
|
| 38 |
+
int64_t half = Dh / 2;
|
| 39 |
+
for (int64_t s = 0; s < L; s++) {
|
| 40 |
+
for (int64_t d = 0; d < Dh; d++) {
|
| 41 |
+
int64_t pair = (d < half) ? d : (d - half);
|
| 42 |
+
float theta_pair = 1.0f / std::pow(theta, (2.0f * pair) / Dh);
|
| 43 |
+
float angle = (float)(p0 + s) * theta_pair;
|
| 44 |
+
cos_h[s * Dh + d] = float_to_bf16(std::cos(angle));
|
| 45 |
+
sin_h[s * Dh + d] = float_to_bf16(std::sin(angle));
|
| 46 |
+
}
|
| 47 |
+
}
|
| 48 |
+
}
|
| 49 |
+
|
| 50 |
+
// Entry point: runs both attention paths (6-token prefill vs 5-prefill +
// 1-decode through a KV cache) on layer 0 and checks their pos-5 outputs
// agree within BF16 tolerance. Exit 0 on PASS.
int main() {
    // NOTE(review): hard-coded paths — point model_dir at the local checkpoint
    // and run from the repo root so data_dir resolves.
    const std::string model_dir = "/path/to/Qwen3-235B-A22B-Instruct-2507-BF16";
    const std::string data_dir = "tests/attn_data";

    ModelConfig cfg;
    if (!cfg.load_from_json(model_dir + "/config.json")) return 1;
    // TP=1, rank 0: full (unsharded) heads for this correctness test.
    cfg.compute_derived(1, 0);
    const int64_t D = cfg.hidden_size;
    const int64_t Hq = cfg.num_attention_heads;
    const int64_t Hkv = cfg.num_key_value_heads;
    const int64_t Dh = cfg.head_dim;
    const int64_t Q_DIM = Hq * Dh;
    const int64_t KV_DIM = Hkv * Dh;
    const double scale = 1.0 / std::sqrt((double)Dh);   // attention softmax scale
    const double eps = cfg.rms_norm_eps;
    const float theta = cfg.rope_theta;

    SafetensorsLoader st;
    if (!st.open(model_dir)) return 1;
    AclRuntime rt;
    rt.init(0);

    // Load shared weights (embeddings) and layer-0 attention weights to device.
    DeviceWeightsLoader dw(st, cfg);
    SharedWeights shared;
    LayerAttnWeights attn;
    printf("Loading weights...\n");
    if (!dw.load_shared(shared)) return 1;
    if (!dw.load_attention(0, attn)) return 1;

    // ---- Load 5 prefill tokens + use token[5]=random as "6th" decoded token ----
    // token_ids.bin format: int32 count followed by `count` int32 ids.
    auto tok_raw = read_file(data_dir + "/token_ids.bin");
    int32_t S_prefill = *(int32_t*)tok_raw.data();
    if (S_prefill < 5) { fprintf(stderr, "need >=5 tokens\n"); return 1; }
    std::vector<int32_t> tokens(S_prefill);
    std::memcpy(tokens.data(), tok_raw.data() + 4, S_prefill * 4);

    // Build 6-token sequence (reuse first 5; pick a 6th token id — use token 0 as a simple choice)
    const int64_t S6 = 6;
    const int64_t S5 = 5;
    std::vector<int32_t> tok6(S6);
    for (int i = 0; i < S5; i++) tok6[i] = tokens[i];
    tok6[5] = tokens[0]; // any token works for cross-consistency test
    printf("tokens6=["); for (auto t : tok6) printf("%d,", t); printf("]\n");

    // ---- Causal mask (2048x2048, sparse_mode=3) shared across both paths ----
    // mask[i][j] = 1 (masked) for j > i, i.e. future positions are blocked.
    const int64_t MASK = 2048;
    DeviceBuffer mask_dev(MASK * MASK);
    std::vector<uint8_t> mask_host(MASK * MASK, 0);
    for (int i = 0; i < MASK; i++)
        for (int j = i+1; j < MASK; j++)
            mask_host[i*MASK + j] = 1;
    ACL_CHECK(aclrtMemcpy(mask_dev.get(), MASK*MASK, mask_host.data(), MASK*MASK, ACL_MEMCPY_HOST_TO_DEVICE));
    auto t_mask = make_contig_tensor(mask_dev.get(), ACL_BOOL, {1, 1, MASK, MASK});

    // =========================================================================
    // PATH A: 6-token prefill (reference)
    // =========================================================================
    printf("\n[Path A] 6-token prefill reference\n");

    // Token ids -> embeddings via index_select into the embedding table.
    DeviceBuffer tokA_dev(S6 * 4);
    ACL_CHECK(aclrtMemcpy(tokA_dev.get(), S6*4, tok6.data(), S6*4, ACL_MEMCPY_HOST_TO_DEVICE));
    auto t_tokA = make_contig_tensor(tokA_dev.get(), ACL_INT32, {S6});
    auto t_embed_w = make_contig_tensor(shared.embed_tokens.get(), ACL_BF16, {cfg.vocab_size, D});

    DeviceBuffer xA_dev(S6 * D * 2);
    auto t_xA = make_contig_tensor(xA_dev.get(), ACL_BF16, {S6, D});
    index_select(rt.stream(), t_embed_w.get(), 0, t_tokA.get(), t_xA.get());
    rt.sync();

    // Pre-attention RMSNorm (input_layernorm).
    DeviceBuffer xnA_dev(S6 * D * 2);
    DeviceBuffer rstdA_dev(S6 * 4);
    auto t_xnA = make_contig_tensor(xnA_dev.get(), ACL_BF16, {S6, D});
    auto t_ln_w = make_contig_tensor(attn.input_layernorm.get(), ACL_BF16, {D});
    auto t_rstdA = make_contig_tensor(rstdA_dev.get(), ACL_FLOAT, {S6});
    rms_norm(rt.stream(), t_xA.get(), t_ln_w.get(), eps, t_xnA.get(), t_rstdA.get());

    // Q/K/V projections.
    DeviceBuffer qA_dev(S6 * Q_DIM * 2);
    DeviceBuffer kA_dev(S6 * KV_DIM * 2);
    DeviceBuffer vA_dev(S6 * KV_DIM * 2);
    auto t_qA = make_contig_tensor(qA_dev.get(), ACL_BF16, {S6, Q_DIM});
    auto t_kA = make_contig_tensor(kA_dev.get(), ACL_BF16, {S6, KV_DIM});
    auto t_vA = make_contig_tensor(vA_dev.get(), ACL_BF16, {S6, KV_DIM});
    linear_hf(rt.stream(), t_xnA.get(), attn.q_proj.get(), ACL_BF16, Q_DIM, D, t_qA.get());
    linear_hf(rt.stream(), t_xnA.get(), attn.k_proj.get(), ACL_BF16, KV_DIM, D, t_kA.get());
    linear_hf(rt.stream(), t_xnA.get(), attn.v_proj.get(), ACL_BF16, KV_DIM, D, t_vA.get());

    // Per-head norm
    // Qwen3 applies RMSNorm per head to Q and K (in place on 4-D views).
    auto t_qA_4d = make_contig_tensor(qA_dev.get(), ACL_BF16, {1, S6, Hq, Dh});
    auto t_kA_4d = make_contig_tensor(kA_dev.get(), ACL_BF16, {1, S6, Hkv, Dh});
    auto t_qn_w = make_contig_tensor(attn.q_norm.get(), ACL_BF16, {Dh});
    auto t_kn_w = make_contig_tensor(attn.k_norm.get(), ACL_BF16, {Dh});
    DeviceBuffer rstd_qA(S6 * Hq * 4), rstd_kA(S6 * Hkv * 4);
    auto t_rstd_qA = make_contig_tensor(rstd_qA.get(), ACL_FLOAT, {1, S6, Hq});
    auto t_rstd_kA = make_contig_tensor(rstd_kA.get(), ACL_FLOAT, {1, S6, Hkv});
    rms_norm(rt.stream(), t_qA_4d.get(), t_qn_w.get(), eps, t_qA_4d.get(), t_rstd_qA.get());
    rms_norm(rt.stream(), t_kA_4d.get(), t_kn_w.get(), eps, t_kA_4d.get(), t_rstd_kA.get());

    // RoPE for positions 0..5
    std::vector<uint16_t> cosA_h, sinA_h;
    fill_cos_sin(cosA_h, sinA_h, 0, S6, Dh, theta);
    DeviceBuffer cosA_dev(S6 * Dh * 2), sinA_dev(S6 * Dh * 2);
    ACL_CHECK(aclrtMemcpy(cosA_dev.get(), S6*Dh*2, cosA_h.data(), S6*Dh*2, ACL_MEMCPY_HOST_TO_DEVICE));
    ACL_CHECK(aclrtMemcpy(sinA_dev.get(), S6*Dh*2, sinA_h.data(), S6*Dh*2, ACL_MEMCPY_HOST_TO_DEVICE));
    DeviceBuffer ropeA_scratch(1 * S6 * Hq * Dh * 2);
    apply_rope_manual(rt.stream(), qA_dev.get(), 1, S6, Hq, Dh, kA_dev.get(), Hkv,
                      cosA_dev.get(), sinA_dev.get(), ropeA_scratch.get());

    // BSH-layout views for the fused attention kernel.
    auto t_qA_bsh = make_contig_tensor(qA_dev.get(), ACL_BF16, {1, S6, Q_DIM});
    auto t_kA_bsh = make_contig_tensor(kA_dev.get(), ACL_BF16, {1, S6, KV_DIM});
    auto t_vA_bsh = make_contig_tensor(vA_dev.get(), ACL_BF16, {1, S6, KV_DIM});

    DeviceBuffer attnA_out(1 * S6 * Q_DIM * 2);
    auto t_attnA_out = make_contig_tensor(attnA_out.get(), ACL_BF16, {1, S6, Q_DIM});
    // Prefill: causal mask with sparse_mode=3, q_len = kv_len = 6.
    fused_infer_attention_score(
        rt.stream(), t_qA_bsh.get(), t_kA_bsh.get(), t_vA_bsh.get(),
        t_mask.get(), {S6}, {S6}, Hq, Hkv, scale, 3, t_attnA_out.get());
    rt.sync();

    // Extract attnA_out[pos=5] into [1, 1, Q_DIM] for comparison
    std::vector<uint16_t> refA_host(Q_DIM);
    ACL_CHECK(aclrtMemcpy(refA_host.data(), Q_DIM*2,
                          (char*)attnA_out.get() + 5 * Q_DIM * 2, Q_DIM*2,
                          ACL_MEMCPY_DEVICE_TO_HOST));
    printf(" attnA_out[5, :4] = %.5f %.5f %.5f %.5f\n",
           bf16_to_float(refA_host[0]), bf16_to_float(refA_host[1]),
           bf16_to_float(refA_host[2]), bf16_to_float(refA_host[3]));

    // =========================================================================
    // PATH B: 5-token prefill + KV cache → 1-token decode
    // =========================================================================
    printf("\n[Path B] 5-prefill + 1-decode via KV cache\n");

    const int64_t MAX_LEN = 128; // small cache for test
    DeviceBuffer k_cache(MAX_LEN * KV_DIM * 2);
    DeviceBuffer v_cache(MAX_LEN * KV_DIM * 2);
    // Zero-init unused slots (not strictly needed, FIAS uses actual_seq_lens).

    // ---- Prefill 5 tokens ----
    DeviceBuffer tokB_dev(S5 * 4);
    ACL_CHECK(aclrtMemcpy(tokB_dev.get(), S5*4, tok6.data(), S5*4, ACL_MEMCPY_HOST_TO_DEVICE));
    auto t_tokB = make_contig_tensor(tokB_dev.get(), ACL_INT32, {S5});
    DeviceBuffer xB_dev(S5 * D * 2);
    auto t_xB = make_contig_tensor(xB_dev.get(), ACL_BF16, {S5, D});
    index_select(rt.stream(), t_embed_w.get(), 0, t_tokB.get(), t_xB.get());
    rt.sync();

    // Same pipeline as Path A, restricted to the first 5 tokens.
    DeviceBuffer xnB_dev(S5 * D * 2);
    DeviceBuffer rstdB_dev(S5 * 4);
    auto t_xnB = make_contig_tensor(xnB_dev.get(), ACL_BF16, {S5, D});
    auto t_rstdB = make_contig_tensor(rstdB_dev.get(), ACL_FLOAT, {S5});
    rms_norm(rt.stream(), t_xB.get(), t_ln_w.get(), eps, t_xnB.get(), t_rstdB.get());

    DeviceBuffer qB_dev(S5 * Q_DIM * 2);
    DeviceBuffer kB_dev(S5 * KV_DIM * 2);
    DeviceBuffer vB_dev(S5 * KV_DIM * 2);
    auto t_qB = make_contig_tensor(qB_dev.get(), ACL_BF16, {S5, Q_DIM});
    auto t_kB = make_contig_tensor(kB_dev.get(), ACL_BF16, {S5, KV_DIM});
    auto t_vB = make_contig_tensor(vB_dev.get(), ACL_BF16, {S5, KV_DIM});
    linear_hf(rt.stream(), t_xnB.get(), attn.q_proj.get(), ACL_BF16, Q_DIM, D, t_qB.get());
    linear_hf(rt.stream(), t_xnB.get(), attn.k_proj.get(), ACL_BF16, KV_DIM, D, t_kB.get());
    linear_hf(rt.stream(), t_xnB.get(), attn.v_proj.get(), ACL_BF16, KV_DIM, D, t_vB.get());

    auto t_qB_4d = make_contig_tensor(qB_dev.get(), ACL_BF16, {1, S5, Hq, Dh});
    auto t_kB_4d = make_contig_tensor(kB_dev.get(), ACL_BF16, {1, S5, Hkv, Dh});
    DeviceBuffer rstd_qB(S5 * Hq * 4), rstd_kB(S5 * Hkv * 4);
    auto t_rstd_qB = make_contig_tensor(rstd_qB.get(), ACL_FLOAT, {1, S5, Hq});
    auto t_rstd_kB = make_contig_tensor(rstd_kB.get(), ACL_FLOAT, {1, S5, Hkv});
    rms_norm(rt.stream(), t_qB_4d.get(), t_qn_w.get(), eps, t_qB_4d.get(), t_rstd_qB.get());
    rms_norm(rt.stream(), t_kB_4d.get(), t_kn_w.get(), eps, t_kB_4d.get(), t_rstd_kB.get());

    std::vector<uint16_t> cosB_h, sinB_h;
    fill_cos_sin(cosB_h, sinB_h, 0, S5, Dh, theta);
    DeviceBuffer cosB_dev(S5 * Dh * 2), sinB_dev(S5 * Dh * 2);
    ACL_CHECK(aclrtMemcpy(cosB_dev.get(), S5*Dh*2, cosB_h.data(), S5*Dh*2, ACL_MEMCPY_HOST_TO_DEVICE));
    ACL_CHECK(aclrtMemcpy(sinB_dev.get(), S5*Dh*2, sinB_h.data(), S5*Dh*2, ACL_MEMCPY_HOST_TO_DEVICE));
    DeviceBuffer ropeB_scratch(1 * S5 * Hq * Dh * 2);
    apply_rope_manual(rt.stream(), qB_dev.get(), 1, S5, Hq, Dh, kB_dev.get(), Hkv,
                      cosB_dev.get(), sinB_dev.get(), ropeB_scratch.get());
    rt.sync();

    // Append K, V to cache at positions 0..4.
    ACL_CHECK(aclrtMemcpy(k_cache.get(), S5 * KV_DIM * 2,
                          kB_dev.get(), S5 * KV_DIM * 2, ACL_MEMCPY_DEVICE_TO_DEVICE));
    ACL_CHECK(aclrtMemcpy(v_cache.get(), S5 * KV_DIM * 2,
                          vB_dev.get(), S5 * KV_DIM * 2, ACL_MEMCPY_DEVICE_TO_DEVICE));
    printf(" cached K/V at positions 0..%ld\n", S5 - 1);

    // ---- Decode 1 token (position = 5) ----
    DeviceBuffer tokD_dev(1 * 4);
    int32_t tok_dec = tok6[5];
    ACL_CHECK(aclrtMemcpy(tokD_dev.get(), 4, &tok_dec, 4, ACL_MEMCPY_HOST_TO_DEVICE));
    auto t_tokD = make_contig_tensor(tokD_dev.get(), ACL_INT32, {1});
    DeviceBuffer xD_dev(1 * D * 2);
    auto t_xD = make_contig_tensor(xD_dev.get(), ACL_BF16, {1, D});
    index_select(rt.stream(), t_embed_w.get(), 0, t_tokD.get(), t_xD.get());

    DeviceBuffer xnD_dev(1 * D * 2), rstdD_dev(1 * 4);
    auto t_xnD = make_contig_tensor(xnD_dev.get(), ACL_BF16, {1, D});
    auto t_rstdD = make_contig_tensor(rstdD_dev.get(), ACL_FLOAT, {1});
    rms_norm(rt.stream(), t_xD.get(), t_ln_w.get(), eps, t_xnD.get(), t_rstdD.get());

    DeviceBuffer qD_dev(1 * Q_DIM * 2), kD_dev(1 * KV_DIM * 2), vD_dev(1 * KV_DIM * 2);
    auto t_qD = make_contig_tensor(qD_dev.get(), ACL_BF16, {1, Q_DIM});
    auto t_kD = make_contig_tensor(kD_dev.get(), ACL_BF16, {1, KV_DIM});
    auto t_vD = make_contig_tensor(vD_dev.get(), ACL_BF16, {1, KV_DIM});
    linear_hf(rt.stream(), t_xnD.get(), attn.q_proj.get(), ACL_BF16, Q_DIM, D, t_qD.get());
    linear_hf(rt.stream(), t_xnD.get(), attn.k_proj.get(), ACL_BF16, KV_DIM, D, t_kD.get());
    linear_hf(rt.stream(), t_xnD.get(), attn.v_proj.get(), ACL_BF16, KV_DIM, D, t_vD.get());

    auto t_qD_4d = make_contig_tensor(qD_dev.get(), ACL_BF16, {1, 1, Hq, Dh});
    auto t_kD_4d = make_contig_tensor(kD_dev.get(), ACL_BF16, {1, 1, Hkv, Dh});
    DeviceBuffer rstd_qD(1 * Hq * 4), rstd_kD(1 * Hkv * 4);
    auto t_rstd_qD = make_contig_tensor(rstd_qD.get(), ACL_FLOAT, {1, 1, Hq});
    auto t_rstd_kD = make_contig_tensor(rstd_kD.get(), ACL_FLOAT, {1, 1, Hkv});
    rms_norm(rt.stream(), t_qD_4d.get(), t_qn_w.get(), eps, t_qD_4d.get(), t_rstd_qD.get());
    rms_norm(rt.stream(), t_kD_4d.get(), t_kn_w.get(), eps, t_kD_4d.get(), t_rstd_kD.get());

    // RoPE for position 5 only
    std::vector<uint16_t> cosD_h, sinD_h;
    fill_cos_sin(cosD_h, sinD_h, /*p0=*/5, /*L=*/1, Dh, theta);
    DeviceBuffer cosD_dev(1 * Dh * 2), sinD_dev(1 * Dh * 2);
    ACL_CHECK(aclrtMemcpy(cosD_dev.get(), Dh*2, cosD_h.data(), Dh*2, ACL_MEMCPY_HOST_TO_DEVICE));
    ACL_CHECK(aclrtMemcpy(sinD_dev.get(), Dh*2, sinD_h.data(), Dh*2, ACL_MEMCPY_HOST_TO_DEVICE));
    DeviceBuffer ropeD_scratch(1 * 1 * Hq * Dh * 2);
    apply_rope_manual(rt.stream(), qD_dev.get(), 1, 1, Hq, Dh, kD_dev.get(), Hkv,
                      cosD_dev.get(), sinD_dev.get(), ropeD_scratch.get());
    rt.sync();

    // Append K, V to cache at position 5.
    ACL_CHECK(aclrtMemcpy((char*)k_cache.get() + S5 * KV_DIM * 2, KV_DIM * 2,
                          kD_dev.get(), KV_DIM * 2, ACL_MEMCPY_DEVICE_TO_DEVICE));
    ACL_CHECK(aclrtMemcpy((char*)v_cache.get() + S5 * KV_DIM * 2, KV_DIM * 2,
                          vD_dev.get(), KV_DIM * 2, ACL_MEMCPY_DEVICE_TO_DEVICE));

    // FIAS decode: q [1, 1, Q_DIM], k/v [1, 6, KV_DIM] from cache.
    auto t_qD_bsh = make_contig_tensor(qD_dev.get(), ACL_BF16, {1, 1, Q_DIM});
    auto t_kC_bsh = make_contig_tensor(k_cache.get(), ACL_BF16, {1, S6, KV_DIM});
    auto t_vC_bsh = make_contig_tensor(v_cache.get(), ACL_BF16, {1, S6, KV_DIM});

    DeviceBuffer attnD_out(1 * 1 * Q_DIM * 2);
    auto t_attnD_out = make_contig_tensor(attnD_out.get(), ACL_BF16, {1, 1, Q_DIM});
    // Decode: q has 1 token, k/v have 6 tokens. Use sparse_mode=0 with no mask — the single q
    // at the end can attend to all cached positions; there's no causal constraint on it.
    fused_infer_attention_score(
        rt.stream(), t_qD_bsh.get(), t_kC_bsh.get(), t_vC_bsh.get(),
        nullptr, {1}, {S6},
        Hq, Hkv, scale, 0, t_attnD_out.get());
    rt.sync();

    std::vector<uint16_t> decB_host(Q_DIM);
    ACL_CHECK(aclrtMemcpy(decB_host.data(), Q_DIM*2, attnD_out.get(), Q_DIM*2, ACL_MEMCPY_DEVICE_TO_HOST));

    // ---- Compare Path A vs Path B ----
    printf("\n attnB_decode[:4] = %.5f %.5f %.5f %.5f\n",
           bf16_to_float(decB_host[0]), bf16_to_float(decB_host[1]),
           bf16_to_float(decB_host[2]), bf16_to_float(decB_host[3]));

    // Relative L2 distance plus max absolute element difference.
    double l2d = 0, l2r = 0, maxd = 0;
    for (int i = 0; i < Q_DIM; i++) {
        float a = bf16_to_float(decB_host[i]), b = bf16_to_float(refA_host[i]);
        l2d += (a-b)*(a-b); l2r += b*b;
        if (std::abs(a-b) > maxd) maxd = std::abs(a-b);
    }
    double rel = std::sqrt(l2d) / (std::sqrt(l2r) + 1e-10);
    printf("\nDecode vs 6-prefill comparison: rel=%.4e max_abs=%.4f\n", rel, maxd);

    // 5% relative tolerance — loose enough for BF16 accumulation differences.
    bool pass = rel < 5e-2;
    printf("\n%s\n", pass ? "=== test_attention_decode PASS ===" : "=== test_attention_decode FAIL ===");
    return pass ? 0 : 1;
}
|
|
@@ -0,0 +1,219 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// test_attention_layer.cpp — full single-layer attention forward (Qwen3-235B layer 0), TP=1.
|
| 2 |
+
// Validates C++ output against Python HF-style reference (attn_data/final_out.bin).
|
| 3 |
+
#include "acl_common.h"
|
| 4 |
+
#include "acl_runtime.h"
|
| 5 |
+
#include "aclnn_ops.h"
|
| 6 |
+
#include "device_weights.h"
|
| 7 |
+
#include "model_config.h"
|
| 8 |
+
#include "rope.h"
|
| 9 |
+
#include "safetensors_loader.h"
|
| 10 |
+
|
| 11 |
+
#include <cmath>
|
| 12 |
+
#include <cstdio>
|
| 13 |
+
#include <cstring>
|
| 14 |
+
#include <fstream>
|
| 15 |
+
#include <vector>
|
| 16 |
+
|
| 17 |
+
// Widen a bfloat16 bit pattern into an IEEE-754 binary32: bf16 is exactly the
// top 16 bits of a float, so shift left and reinterpret via memcpy.
static float bf16_to_float(uint16_t x) {
    const uint32_t bits = static_cast<uint32_t>(x) << 16;
    float out;
    std::memcpy(&out, &bits, sizeof(out));
    return out;
}
|
| 20 |
+
// Convert binary32 -> bfloat16 with round-to-nearest-even.
// Fix: the rounding add below can carry out of the mantissa into the exponent,
// which silently turns a NaN with a small payload (e.g. bits 0x7F800001) into
// +/-Inf (0x7F80). Detect NaN first and return a quiet NaN with the sign and
// top payload bits preserved.
static uint16_t float_to_bf16(float x) {
    uint32_t u; std::memcpy(&u, &x, 4);
    if ((u & 0x7F800000u) == 0x7F800000u && (u & 0x007FFFFFu) != 0)
        return (uint16_t)((u >> 16) | 0x0040);  // force quiet-NaN bit, keep sign
    // round-to-nearest-even: add half-ulp plus the parity of the kept LSB
    return (uint16_t)((u + 0x7FFF + ((u >> 16) & 1)) >> 16);
}
|
| 24 |
+
|
| 25 |
+
// Slurp an entire binary file into memory.
// Fix: the original never checked that the open succeeded; on a missing file
// tellg() returns -1, which wraps to a huge size_t and the vector allocation
// throws (or worse). Return an empty vector on any failure so callers can
// detect the problem via .empty() / .size() == 0.
static std::vector<uint8_t> read_file(const std::string& p) {
    std::ifstream f(p, std::ios::binary | std::ios::ate);
    if (!f) return {};                       // file missing / unreadable
    const std::streamsize s = f.tellg();
    if (s <= 0) return {};                   // empty file or tellg failure
    std::vector<uint8_t> v((size_t)s);
    f.seekg(0);
    if (!f.read((char*)v.data(), s)) return {};  // short read -> treat as failure
    return v;
}
|
| 29 |
+
|
| 30 |
+
// Entry point: runs one full Qwen3 layer-0 attention forward (prefill, TP=1) on
// real checkpoint weights and compares the result against a Python-generated
// reference (tests/attn_data/final_out.bin). Returns 0 on PASS, 1 on FAIL or
// on any load error. Pipeline: embed -> input RMSNorm -> Q/K/V proj -> per-head
// q/k RMSNorm -> RoPE -> fused attention -> O proj -> residual add.
int main() {
  // NOTE(review): hard-coded paths — edit model_dir for your checkpoint layout.
  const std::string model_dir = "/path/to/Qwen3-235B-A22B-Instruct-2507-BF16";
  const std::string data_dir = "tests/attn_data";

  ModelConfig cfg;
  if (!cfg.load_from_json(model_dir + "/config.json")) return 1;
  cfg.compute_derived(/*tp_size=*/1, /*tp_rank=*/0); // single rank for correctness test
  // Full (un-sharded) head counts are used here since TP=1.
  const int64_t D = cfg.hidden_size;
  const int64_t Hq = cfg.num_attention_heads;
  const int64_t Hkv = cfg.num_key_value_heads;
  const int64_t Dh = cfg.head_dim;
  const int64_t Q_DIM = Hq * Dh;
  const int64_t KV_DIM = Hkv * Dh;
  const double scale = 1.0 / std::sqrt((double)Dh);  // 1/sqrt(head_dim) softmax scale
  const double eps = cfg.rms_norm_eps;
  const float theta = cfg.rope_theta;

  SafetensorsLoader st;
  if (!st.open(model_dir)) return 1;

  AclRuntime rt;
  rt.init(0);  // device 0

  // ---- Load weights (layer 0 attention + embed) ----
  DeviceWeightsLoader dw(st, cfg);
  SharedWeights shared;
  LayerAttnWeights attn;
  printf("Loading weights...\n");
  if (!dw.load_shared(shared)) return 1;
  if (!dw.load_attention(0, attn)) return 1;
  printf("  shared.embed %.0fMB, attn total ~140MB\n", shared.embed_tokens.size / 1e6);

  // ---- Load token ids (5 tokens: "The capital of France is") ----
  // File layout: int32 count S followed by S int32 token ids.
  // NOTE(review): no size/validity check on token_ids.bin — a missing file
  // would dereference an empty buffer here; confirm the fixture always exists.
  auto tok_raw = read_file(data_dir + "/token_ids.bin");
  int32_t S = *(int32_t*)tok_raw.data();
  std::vector<int32_t> tokens(S);
  std::memcpy(tokens.data(), tok_raw.data() + 4, S * 4);
  printf("S=%d tokens=[", S); for (auto t : tokens) printf("%d,", t); printf("]\n");

  // ---- Embed lookup: [S, D] ----
  DeviceBuffer tok_dev(S * 4);
  ACL_CHECK(aclrtMemcpy(tok_dev.get(), S * 4, tokens.data(), S * 4, ACL_MEMCPY_HOST_TO_DEVICE));
  auto t_tok = make_contig_tensor(tok_dev.get(), ACL_INT32, {S});

  // embed weight shape [vocab, D]
  auto t_embed_w = make_contig_tensor(shared.embed_tokens.get(), ACL_BF16, {cfg.vocab_size, D});

  DeviceBuffer x_dev(S * D * 2);  // *2 = sizeof(BF16)
  auto t_x = make_contig_tensor(x_dev.get(), ACL_BF16, {S, D});
  index_select(rt.stream(), t_embed_w.get(), 0, t_tok.get(), t_x.get());
  rt.sync();

  // ---- Residual snapshot (copy x) ----
  // Kept so the post-attention residual add uses the pre-norm activations.
  DeviceBuffer residual_dev(S * D * 2);
  ACL_CHECK(aclrtMemcpyAsync(residual_dev.get(), S*D*2, x_dev.get(), S*D*2, ACL_MEMCPY_DEVICE_TO_DEVICE, rt.stream()));

  // ---- Input layernorm ----
  DeviceBuffer xn_dev(S * D * 2);
  DeviceBuffer rstd_dev(S * 4);  // float32 reciprocal-std per row (op side output)
  auto t_xn = make_contig_tensor(xn_dev.get(), ACL_BF16, {S, D});
  auto t_ln_w = make_contig_tensor(attn.input_layernorm.get(), ACL_BF16, {D});
  auto t_rstd = make_contig_tensor(rstd_dev.get(), ACL_FLOAT, {S});
  rms_norm(rt.stream(), t_x.get(), t_ln_w.get(), eps, t_xn.get(), t_rstd.get());

  // ---- Q/K/V projections (linear_hf: y = x @ W.T, W stored as [out, in]) ----
  DeviceBuffer q_dev(S * Q_DIM * 2);
  DeviceBuffer k_dev(S * KV_DIM * 2);
  DeviceBuffer v_dev(S * KV_DIM * 2);
  auto t_q = make_contig_tensor(q_dev.get(), ACL_BF16, {S, Q_DIM});
  auto t_k = make_contig_tensor(k_dev.get(), ACL_BF16, {S, KV_DIM});
  auto t_v = make_contig_tensor(v_dev.get(), ACL_BF16, {S, KV_DIM});
  linear_hf(rt.stream(), t_xn.get(), attn.q_proj.get(), ACL_BF16, Q_DIM, D, t_q.get());
  linear_hf(rt.stream(), t_xn.get(), attn.k_proj.get(), ACL_BF16, KV_DIM, D, t_k.get());
  linear_hf(rt.stream(), t_xn.get(), attn.v_proj.get(), ACL_BF16, KV_DIM, D, t_v.get());

  // ---- Reshape Q, K as [B=1, S, N, Dh] for q_norm/k_norm + RoPE ----
  // Same memory; just new views.
  // q_dev has S * Q_DIM = S * Hq * Dh BF16
  auto t_q_4d = make_contig_tensor(q_dev.get(), ACL_BF16, {1, S, Hq, Dh});
  auto t_k_4d = make_contig_tensor(k_dev.get(), ACL_BF16, {1, S, Hkv, Dh});

  // Per-head RmsNorm on last dim (gamma shape [Dh]) — Qwen3-specific q/k norm.
  auto t_qn_w = make_contig_tensor(attn.q_norm.get(), ACL_BF16, {Dh});
  auto t_kn_w = make_contig_tensor(attn.k_norm.get(), ACL_BF16, {Dh});
  DeviceBuffer rstd_q_dev(S * Hq * 4); // rstd shape = q's all-but-last dims
  DeviceBuffer rstd_k_dev(S * Hkv * 4);
  auto t_rstd_q = make_contig_tensor(rstd_q_dev.get(), ACL_FLOAT, {1, S, Hq});
  auto t_rstd_k = make_contig_tensor(rstd_k_dev.get(), ACL_FLOAT, {1, S, Hkv});
  // RmsNorm in place on q/k
  rms_norm(rt.stream(), t_q_4d.get(), t_qn_w.get(), eps, t_q_4d.get(), t_rstd_q.get());
  rms_norm(rt.stream(), t_k_4d.get(), t_kn_w.get(), eps, t_k_4d.get(), t_rstd_k.get());

  // ---- Compute cos/sin on device ----
  // cos/sin shape [1, S, Dh] BF16, built on host then copied over.
  std::vector<uint16_t> cos_host(S * Dh), sin_host(S * Dh);
  for (int s = 0; s < S; s++) {
    for (int64_t d = 0; d < Dh; d++) {
      // freq index: for half-half layout, index d corresponds to pair index (d % (Dh/2))
      int64_t half = Dh / 2;
      int64_t pair = (d < half) ? d : (d - half);
      float theta_pair = 1.0f / std::pow(theta, (2.0f * pair) / Dh);
      float angle = (float)s * theta_pair;  // position * inverse frequency
      cos_host[s * Dh + d] = float_to_bf16(std::cos(angle));
      sin_host[s * Dh + d] = float_to_bf16(std::sin(angle));
    }
  }
  DeviceBuffer cos_dev(S * Dh * 2);
  DeviceBuffer sin_dev(S * Dh * 2);
  ACL_CHECK(aclrtMemcpy(cos_dev.get(), S*Dh*2, cos_host.data(), S*Dh*2, ACL_MEMCPY_HOST_TO_DEVICE));
  ACL_CHECK(aclrtMemcpy(sin_dev.get(), S*Dh*2, sin_host.data(), S*Dh*2, ACL_MEMCPY_HOST_TO_DEVICE));

  // ---- RoPE ----
  // Rotates q and k in place; rope_scratch is sized for the larger (q) tensor.
  DeviceBuffer rope_scratch(1 * S * Hq * Dh * 2);
  apply_rope_manual(rt.stream(),
                    q_dev.get(), 1, S, Hq, Dh,
                    k_dev.get(), Hkv,
                    cos_dev.get(), sin_dev.get(),
                    rope_scratch.get());

  // ---- FIAS (fused attention) ----
  // q/k/v are reshaped back to BSH [1, S, Hq*Dh or Hkv*Dh]
  auto t_q_bsh = make_contig_tensor(q_dev.get(), ACL_BF16, {1, S, Q_DIM});
  auto t_k_bsh = make_contig_tensor(k_dev.get(), ACL_BF16, {1, S, KV_DIM});
  auto t_v_bsh = make_contig_tensor(v_dev.get(), ACL_BF16, {1, S, KV_DIM});

  // Causal mask 2048x2048 (sparse_mode=3 requires fixed size)
  const int64_t MASK = 2048;
  DeviceBuffer mask_dev(MASK * MASK); // bool = 1 byte
  std::vector<uint8_t> mask_host(MASK * MASK, 0);
  for (int i = 0; i < MASK; i++)
    for (int j = i+1; j < MASK; j++)
      mask_host[i*MASK + j] = 1; // upper triangular = True (masked-out positions)
  ACL_CHECK(aclrtMemcpy(mask_dev.get(), MASK*MASK, mask_host.data(), MASK*MASK, ACL_MEMCPY_HOST_TO_DEVICE));
  auto t_mask = make_contig_tensor(mask_dev.get(), ACL_BOOL, {1, 1, MASK, MASK});

  DeviceBuffer attn_out_dev(1 * S * Q_DIM * 2);
  auto t_attn_out = make_contig_tensor(attn_out_dev.get(), ACL_BF16, {1, S, Q_DIM});

  fused_infer_attention_score(
      rt.stream(),
      t_q_bsh.get(), t_k_bsh.get(), t_v_bsh.get(),
      t_mask.get(),
      {S}, {S},          // actual q-len / kv-len per batch
      Hq, Hkv,
      scale,
      3, // sparse_mode = causal
      t_attn_out.get());

  // ---- O projection ----
  auto t_attn_out_2d = make_contig_tensor(attn_out_dev.get(), ACL_BF16, {S, Q_DIM});
  DeviceBuffer o_dev(S * D * 2);
  auto t_o = make_contig_tensor(o_dev.get(), ACL_BF16, {S, D});
  linear_hf(rt.stream(), t_attn_out_2d.get(), attn.o_proj.get(), ACL_BF16, D, Q_DIM, t_o.get());

  // ---- Residual add: out = residual + o ----
  auto t_res = make_contig_tensor(residual_dev.get(), ACL_BF16, {S, D});
  float alpha_v = 1.0f;
  aclScalar* alpha = aclCreateScalar(&alpha_v, ACL_FLOAT);
  DeviceBuffer out_dev(S * D * 2);
  auto t_out = make_contig_tensor(out_dev.get(), ACL_BF16, {S, D});
  {
    // Standard aclnn two-phase call: query workspace size, then launch.
    uint64_t ws = 0; aclOpExecutor* e = nullptr;
    ACLNN_CHECK(aclnnAddGetWorkspaceSize(t_res.get(), t_o.get(), alpha, t_out.get(), &ws, &e));
    DeviceBuffer wb; if (ws > 0) wb.alloc(ws);
    ACLNN_CHECK(aclnnAdd(wb.get(), ws, e, rt.stream()));
  }
  aclDestroyScalar(alpha);
  rt.sync();  // everything enqueued above completes before host readback

  // ---- Compare with Python reference ----
  auto ref_h = read_file(data_dir + "/final_out.bin");
  std::vector<uint16_t> cxx(S * D);
  ACL_CHECK(aclrtMemcpy(cxx.data(), S*D*2, out_dev.get(), S*D*2, ACL_MEMCPY_DEVICE_TO_HOST));
  auto* ref = (const uint16_t*)ref_h.data();

  // Relative L2 error plus max absolute elementwise difference.
  double l2d = 0, l2r = 0, maxd = 0;
  for (int i = 0; i < S * D; i++) {
    float a = bf16_to_float(cxx[i]), b = bf16_to_float(ref[i]);
    l2d += (a-b)*(a-b); l2r += b*b;
    if (std::abs(a-b) > maxd) maxd = std::abs(a-b);
  }
  double rel = std::sqrt(l2d) / (std::sqrt(l2r) + 1e-10);
  printf("\nAttention layer output compare: rel=%.4e max_abs=%.4f\n", rel, maxd);
  printf(" cxx[0, :4]: "); for (int i = 0; i < 4; i++) printf("%.6f ", bf16_to_float(cxx[i]));
  printf("\n ref[0, :4]: "); for (int i = 0; i < 4; i++) printf("%.6f ", bf16_to_float(ref[i])); printf("\n");

  bool pass = rel < 5e-2; // BF16 accumulation across 5+ ops loses ~1-2% per step
  printf("\n%s\n", pass ? "=== test_attention_layer PASS ===" : "=== test_attention_layer FAIL ===");
  return pass ? 0 : 1;
}
|
|
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// test_batch_correctness.cpp — verify that forward with S>1 at past_len>0 produces the
|
| 2 |
+
// same logits at each position as sequential S=1 decodes.
|
| 3 |
+
//
|
| 4 |
+
// This is the foundation for speculative decoding / PLD: the main model must predict logits
|
| 5 |
+
// for each of K candidate positions in one batched forward pass matching sequential behavior.
|
| 6 |
+
#include "runner.h"
|
| 7 |
+
|
| 8 |
+
#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <vector>
|
| 12 |
+
|
| 13 |
+
// Decode a bfloat16 value: bf16 is the upper half of an IEEE-754 binary32,
// so placing the stored bits in the high 16 bits reconstructs the float.
static float bf16_to_float(uint16_t x) {
    float result = 0.0f;
    const uint32_t widened = static_cast<uint32_t>(x) << 16;
    std::memcpy(&result, &widened, sizeof(widened));
    return result;
}
|
| 16 |
+
|
| 17 |
+
int main() {
|
| 18 |
+
const std::string model_dir = "/path/to/Qwen3-235B-A22B-Instruct-2507-BF16";
|
| 19 |
+
int tp_rank = 0, tp_size = 1;
|
| 20 |
+
if (const char* v = std::getenv("TP_RANK")) tp_rank = std::atoi(v);
|
| 21 |
+
if (const char* v = std::getenv("TP_SIZE")) tp_size = std::atoi(v);
|
| 22 |
+
bool is_master = tp_rank == 0;
|
| 23 |
+
|
| 24 |
+
Runner r;
|
| 25 |
+
if (!r.init(model_dir, tp_size, tp_rank, 94, 512)) return 1;
|
| 26 |
+
const int64_t V = r.cfg().vocab_size;
|
| 27 |
+
|
| 28 |
+
// Prefix
|
| 29 |
+
std::vector<int32_t> prompt = {785, 6722, 315, 9625, 374};
|
| 30 |
+
DeviceBuffer logits0;
|
| 31 |
+
r.prefill(prompt.data(), prompt.size(), logits0);
|
| 32 |
+
std::vector<uint16_t> h_last0(V);
|
| 33 |
+
if (is_master) ACL_CHECK(aclrtMemcpy(h_last0.data(), V*2, logits0.get(), V*2, ACL_MEMCPY_DEVICE_TO_HOST));
|
| 34 |
+
int next0 = 0;
|
| 35 |
+
if (is_master) {
|
| 36 |
+
float best = -1e30; for (int i = 0; i < V; i++) { float v = bf16_to_float(h_last0[i]); if (v > best) { best = v; next0 = i; } }
|
| 37 |
+
}
|
| 38 |
+
// Broadcast next0 (simple: let rank 0 decide and non-master ranks independently too)
|
| 39 |
+
int32_t token_seq[4];
|
| 40 |
+
if (is_master) token_seq[0] = next0;
|
| 41 |
+
|
| 42 |
+
// --- Path A: sequential S=1 decode × 4 times ---
|
| 43 |
+
std::vector<std::vector<uint16_t>> seq_logits(4);
|
| 44 |
+
for (int i = 0; i < 4; i++) seq_logits[i].resize(V);
|
| 45 |
+
|
| 46 |
+
// first decode: takes prompt's last logit argmax
|
| 47 |
+
// Here we need identical approach on all ranks. Use random token id for consistency.
|
| 48 |
+
std::vector<int32_t> seq_tokens = {next0, 100, 200, 300}; // deterministic for test
|
| 49 |
+
|
| 50 |
+
for (int i = 0; i < 4; i++) {
|
| 51 |
+
DeviceBuffer out;
|
| 52 |
+
r.decode(seq_tokens[i], out);
|
| 53 |
+
if (is_master) ACL_CHECK(aclrtMemcpy(seq_logits[i].data(), V*2, out.get(), V*2, ACL_MEMCPY_DEVICE_TO_HOST));
|
| 54 |
+
}
|
| 55 |
+
int64_t past_after_seq = r.past_len();
|
| 56 |
+
|
| 57 |
+
// --- Path B: reset, re-prefill, then ONE batch forward with S=4 ---
|
| 58 |
+
r.reset_cache();
|
| 59 |
+
DeviceBuffer logits_reprefill;
|
| 60 |
+
r.prefill(prompt.data(), prompt.size(), logits_reprefill);
|
| 61 |
+
|
| 62 |
+
DeviceBuffer batch_logits;
|
| 63 |
+
r.prefill(seq_tokens.data(), 4, batch_logits);
|
| 64 |
+
// prefill returns logits for LAST position only (S=4 gives [1, V], not [4, V]).
|
| 65 |
+
// Hmm — that's a limitation. To do PLD we need logits for all 4 positions.
|
| 66 |
+
// For now, just compare the LAST one (position 4 after prefix).
|
| 67 |
+
|
| 68 |
+
std::vector<uint16_t> batch_last(V);
|
| 69 |
+
if (is_master) ACL_CHECK(aclrtMemcpy(batch_last.data(), V*2, batch_logits.get(), V*2, ACL_MEMCPY_DEVICE_TO_HOST));
|
| 70 |
+
|
| 71 |
+
if (is_master) {
|
| 72 |
+
printf("\n=== Batched vs Sequential Decode Correctness ===\n");
|
| 73 |
+
double l2d=0, l2r=0, maxd=0;
|
| 74 |
+
for (int i = 0; i < V; i++) {
|
| 75 |
+
float a = bf16_to_float(batch_last[i]), b = bf16_to_float(seq_logits[3][i]);
|
| 76 |
+
l2d += (a-b)*(a-b); l2r += b*b;
|
| 77 |
+
if (std::abs(a-b) > maxd) maxd = std::abs(a-b);
|
| 78 |
+
}
|
| 79 |
+
double rel = std::sqrt(l2d) / (std::sqrt(l2r) + 1e-10);
|
| 80 |
+
printf("Last-position logits:\n");
|
| 81 |
+
printf(" seq[3] argmax = "); {
|
| 82 |
+
int b = 0; float bv = bf16_to_float(seq_logits[3][0]);
|
| 83 |
+
for (int i = 1; i < V; i++) if (bf16_to_float(seq_logits[3][i]) > bv) { bv = bf16_to_float(seq_logits[3][i]); b = i; }
|
| 84 |
+
printf("%d (%.3f)\n", b, bv);
|
| 85 |
+
}
|
| 86 |
+
printf(" batch argmax = "); {
|
| 87 |
+
int b = 0; float bv = bf16_to_float(batch_last[0]);
|
| 88 |
+
for (int i = 1; i < V; i++) if (bf16_to_float(batch_last[i]) > bv) { bv = bf16_to_float(batch_last[i]); b = i; }
|
| 89 |
+
printf("%d (%.3f)\n", b, bv);
|
| 90 |
+
}
|
| 91 |
+
printf(" rel=%.4e max=%.4f\n", rel, maxd);
|
| 92 |
+
printf(" %s\n", rel < 5e-2 ? "PASS" : "FAIL (batch forward diverges from sequential)");
|
| 93 |
+
printf("\nNote: current Runner.prefill() returns ONLY last-position logits. For PLD\n");
|
| 94 |
+
printf("we need all-position logits: requires extending prefill to optionally output\n");
|
| 95 |
+
printf("[S, V] logits tensor.\n");
|
| 96 |
+
}
|
| 97 |
+
return 0;
|
| 98 |
+
}
|
|
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// test_batch_decode.cpp — benchmark decode with different batch sizes S = 1, 2, 4, 8.
|
| 2 |
+
//
|
| 3 |
+
// Purpose: quantify the cost of "batched decode" (a.k.a. the ingredient speculative decoding
|
| 4 |
+
// relies on). If Runner.prefill(S=K) forward-pass is only a small overhead over S=1, then
|
| 5 |
+
// spec-decoding with K draft tokens gives ~K× speedup at high accept rate.
|
| 6 |
+
//
|
| 7 |
+
// Per-token amortized cost:
|
| 8 |
+
// cost(S) / S
|
| 9 |
+
// Speculative decoding benefit:
|
| 10 |
+
// expected_accept_rate * K = effective tokens per forward
|
| 11 |
+
// → TG = expected / (cost(S=K+1) / 1_sec)
|
| 12 |
+
#include "runner.h"
|
| 13 |
+
|
| 14 |
+
#include <algorithm>
#include <chrono>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <vector>
|
| 18 |
+
|
| 19 |
+
// Entry point: benchmarks a batched forward pass with S = 1, 2, 4, 8 new tokens
// after a short prefill, reporting median latency, amortized per-token cost,
// and efficiency relative to S=1. This quantifies the headroom available to
// speculative decoding (which verifies K draft tokens in one S=K forward).
// TP rank/size come from TP_RANK / TP_SIZE env vars; rank 0 prints the table.
int main() {
  const std::string model_dir = "/path/to/Qwen3-235B-A22B-Instruct-2507-BF16";
  Runner r;
  int tp_rank = 0, tp_size = 1;
  if (const char* v = std::getenv("TP_RANK")) tp_rank = std::atoi(v);
  if (const char* v = std::getenv("TP_SIZE")) tp_size = std::atoi(v);
  bool is_master = tp_rank == 0;

  if (!r.init(model_dir, tp_size, tp_rank, /*num_layers=*/94, /*max_seq=*/512)) return 1;

  // Prefill a short context so decode has some KV cache
  std::vector<int32_t> prompt = {785, 6722, 315, 9625, 374}; // "The capital of France is"
  DeviceBuffer logits;
  r.prefill(prompt.data(), prompt.size(), logits);

  // Host-side wall-clock timing helpers.
  // NOTE(review): this assumes r.prefill() blocks until the device finishes;
  // if it only enqueues work, these timings under-measure — confirm.
  auto now = []() { return std::chrono::steady_clock::now(); };
  auto ms = [](auto t0, auto t1) { return std::chrono::duration<double, std::milli>(t1 - t0).count(); };

  std::vector<int> batch_sizes = {1, 2, 4, 8};
  int N_ITERS = 20;  // measured iterations per batch size (median reported)

  if (is_master) {
    printf("\n=== Batched decode forward benchmark (94 layers, TP=%d) ===\n", tp_size);
    printf("Each row: forward with S=K new tokens after prefill\n");
    printf("%-5s %-12s %-18s %-18s %s\n",
           "S", "ms/forward", "ms/token (amort)", "tokens/sec", "vs S=1 efficiency");
  }

  double base_per_token = 0;  // per-token cost at S=1, baseline for efficiency
  for (int S : batch_sizes) {
    // Reset cache between measurements to keep cache size fair (same position for each)
    // Actually we want to simulate: after some past_len, do 1 forward with S new tokens.
    // Use prefill() which accepts S>=1.

    std::vector<double> times;
    for (int iter = 0; iter < N_ITERS + 3; iter++) { // +3 for warmup
      r.reset_cache();
      r.prefill(prompt.data(), prompt.size(), logits); // re-prefill

      // New forward with S tokens (as if doing speculative verify)
      std::vector<int32_t> new_tokens(S, 100); // dummy token ids
      auto t0 = now();
      DeviceBuffer logits2;
      r.prefill(new_tokens.data(), S, logits2);
      auto t1 = now();
      if (iter >= 3) times.push_back(ms(t0, t1));  // discard the 3 warmup runs
    }
    std::sort(times.begin(), times.end());
    double median_ms = times[times.size() / 2];  // median is robust to outliers
    double per_token = median_ms / S;
    double tok_per_sec = 1000.0 / per_token;
    if (S == 1) base_per_token = per_token;      // batch_sizes starts at 1, so set first
    double efficiency = base_per_token / per_token * 100.0;

    if (is_master) {
      printf("%-5d %-12.2f %-18.2f %-18.2f %.1f%%\n",
             S, median_ms, per_token, tok_per_sec, efficiency);
    }
  }

  if (is_master) {
    printf("\n=== Interpretation ===\n");
    printf("If S=4 forward ~ S=1 (efficiency high), spec decoding with accept_rate=70%%\n");
    printf("gives TG = 0.7*4 / cost(S=5) vs baseline 1 / cost(S=1) = up to 2.8× speedup.\n");
  }
  return 0;
}
|
|
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
# test_chat_flow.sh — end-to-end integration smoke test for the CLI.
#
# Exercises:
#   - --prompt-file
#   - Multi-turn --chat memory (remembers Alice's name in turn 2)
#   - --reset command in REPL
#   - --system prompt
#   - EOS detection at <|im_end|>
#
# Requires TP=16 Ascend 910 setup. Run from the repo root.
#
# Exit: 0 on all-pass, nonzero with reason.
# NOTE(review): the final `exit $fail` is truncated mod 256 by the shell;
# fine while the test count stays small.
set -u
BIN="./build/qwen3-moe-aclnn"
MODEL="${MODEL_DIR:-/path/to/Qwen3-235B-A22B-Instruct-2507-BF16}"
LAUNCH="./scripts/tp_launch.sh"
TP="${TP_SIZE:-16}"
VOCAB="tokenizer_data/vocab.bin"

[ -x "$BIN" ] || { echo "FAIL: $BIN not built"; exit 1; }
[ -x "$LAUNCH" ] || { echo "FAIL: $LAUNCH not found"; exit 1; }

pass=0; fail=0
# check NAME OUTPUT NEEDLE — case-insensitive fixed-string grep of NEEDLE in
# OUTPUT; increments pass/fail counters and dumps the output tail on failure.
check() {
  local name="$1"; shift
  local out="$1"; shift
  local needle="$1"; shift
  if echo "$out" | grep -qiF "$needle"; then
    echo "  [PASS] $name (found: '$needle')"; pass=$((pass+1))
  else
    echo "  [FAIL] $name (did NOT find: '$needle')"; fail=$((fail+1))
    echo "  ---- output ----"; echo "$out" | tail -20; echo "  ---- end ----"
  fi
}

echo "===== Test 1: --prompt-file + EOS ====="
# Fixed /tmp path: fine for a single-user test box; mktemp would be safer on
# shared hosts.
echo "What is the capital of Japan?" > /tmp/chat_test_prompt.txt
OUT=$(${LAUNCH} ${TP} ${BIN} --model-dir "$MODEL" \
      --prompt-file /tmp/chat_test_prompt.txt \
      --chat --n-predict 50 --temperature 0 --vocab "$VOCAB" 2>&1)
check "prompt-file loaded" "$OUT" "capital of Japan"
check "answer mentions Tokyo" "$OUT" "Tokyo"
check "hit EOS" "$OUT" "hit EOS"

echo ""
echo "===== Test 2: multi-turn memory (remembers name) ====="
OUT=$(printf "My name is Alice.\nWhat is my name?\nquit\n" | \
      ${LAUNCH} ${TP} ${BIN} --model-dir "$MODEL" \
      --interactive --chat \
      --system "You are a concise assistant. Answer in one short sentence." \
      --temperature 0 --n-predict 40 --max-seq 512 \
      --vocab "$VOCAB" 2>&1)
check "recalls Alice" "$OUT" "Alice"
check "has 2 turns" "$OUT" "past_len="

echo ""
echo "===== Test 3: reset command clears memory ====="
OUT=$(printf "My name is Bob.\nreset\nWhat is my name?\nquit\n" | \
      ${LAUNCH} ${TP} ${BIN} --model-dir "$MODEL" \
      --interactive --chat \
      --system "Answer truthfully in one sentence." \
      --temperature 0 --n-predict 40 --max-seq 512 \
      --vocab "$VOCAB" 2>&1)
check "reset acknowledged" "$OUT" "cache + conversation reset"
# After reset, model should NOT know the name is Bob (probably says "don't know" or asks)
# We can't reliably check negation, so just check that the reset ran and turn 3 produced output
# (presumably "bye" is the CLI's quit message — verify against the REPL code).
check "turn 3 ran" "$OUT" "bye"

echo ""
echo "===== Summary: $pass passed, $fail failed ====="
exit $fail
|
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// test_engine_smoke.cpp — just verify engine.h compiles and links.
|
| 2 |
+
#include "engine.h"
|
| 3 |
+
|
| 4 |
+
// Compile-and-link smoke test: succeeding at build time is the whole test.
int main() {
  // Intentionally a no-op. Everything in engine.h needs a live NPU runtime and
  // real weights, so this target only proves the header parses and the core
  // library links cleanly.
  return 0;
}
|
|
@@ -0,0 +1,192 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// test_layer_forward.cpp — integration test for one full transformer layer via engine.h.
|
| 2 |
+
//
|
| 3 |
+
// Chain: embed_5_tokens → attention_forward (prefill, past=0) → +residual → moe_forward → +residual
|
| 4 |
+
// Expected: final output matches moe_data/final_out.bin within BF16 precision (rel < 5e-2).
|
| 5 |
+
#include "acl_common.h"
|
| 6 |
+
#include "acl_runtime.h"
|
| 7 |
+
#include "aclnn_ops.h"
|
| 8 |
+
#include "device_weights.h"
|
| 9 |
+
#include "engine.h"
|
| 10 |
+
#include "model_config.h"
|
| 11 |
+
#include "safetensors_loader.h"
|
| 12 |
+
|
| 13 |
+
#include <cmath>
|
| 14 |
+
#include <cstdio>
|
| 15 |
+
#include <cstring>
|
| 16 |
+
#include <fstream>
|
| 17 |
+
#include <vector>
|
| 18 |
+
|
| 19 |
+
static float bf16_to_float(uint16_t x) {
    // bf16 -> f32: the stored 16 bits become the high half; low bits are zero.
    uint32_t raw = x;
    raw <<= 16;
    float value;
    std::memcpy(&value, &raw, sizeof(float));
    return value;
}
|
| 22 |
+
// Slurp an entire binary file into memory.
// Fix: the original never checked the open; on a missing file tellg() returns
// -1, which wraps to a huge size_t allocation. Return an empty vector on any
// failure so callers can detect the problem via .empty().
static std::vector<uint8_t> read_file(const std::string& p) {
    std::ifstream f(p, std::ios::binary | std::ios::ate);
    if (!f) return {};                       // file missing / unreadable
    const std::streamsize s = f.tellg();
    if (s <= 0) return {};                   // empty file or tellg failure
    std::vector<uint8_t> v((size_t)s);
    f.seekg(0);
    if (!f.read((char*)v.data(), s)) return {};  // short read -> failure
    return v;
}
|
| 26 |
+
|
| 27 |
+
// Add: out = a + b (BF16). Elementwise tensor add enqueued on `stream` via the
// aclnn two-phase pattern (query workspace size, then launch). alpha = 1.0f is
// the scale applied to the second operand by aclnnAdd.
// NOTE(review): the call only enqueues work; `wb` (the workspace buffer) is
// destroyed when this function returns, possibly before the kernel runs —
// confirm DeviceBuffer frees stream-ordered, or that callers sync promptly.
static void bf16_add(aclrtStream stream, aclTensor* a, aclTensor* b, aclTensor* out) {
    float alpha = 1.0f; aclScalar* al = aclCreateScalar(&alpha, ACL_FLOAT);
    uint64_t ws = 0; aclOpExecutor* e = nullptr;
    ACLNN_CHECK(aclnnAddGetWorkspaceSize(a, b, al, out, &ws, &e));
    DeviceBuffer wb; if (ws > 0) wb.alloc(ws);
    ACLNN_CHECK(aclnnAdd(wb.get(), ws, e, stream));
    aclDestroyScalar(al);
}
|
| 36 |
+
|
| 37 |
+
int main() {
|
| 38 |
+
const std::string model_dir = "/path/to/Qwen3-235B-A22B-Instruct-2507-BF16";
|
| 39 |
+
const std::string attn_data = "tests/attn_data";
|
| 40 |
+
const std::string moe_data = "tests/moe_data";
|
| 41 |
+
|
| 42 |
+
ModelConfig cfg;
|
| 43 |
+
if (!cfg.load_from_json(model_dir + "/config.json")) return 1;
|
| 44 |
+
cfg.compute_derived(1, 0);
|
| 45 |
+
const int64_t D = cfg.hidden_size;
|
| 46 |
+
const int64_t Hq = cfg.n_heads_per_rank;
|
| 47 |
+
const int64_t Hkv = cfg.n_kv_heads_per_rank;
|
| 48 |
+
const int64_t Dh = cfg.head_dim;
|
| 49 |
+
const int64_t Q_DIM = Hq * Dh;
|
| 50 |
+
const int64_t KV_DIM = Hkv * Dh;
|
| 51 |
+
const int64_t I = cfg.i_per_rank;
|
| 52 |
+
const int64_t E = cfg.num_experts;
|
| 53 |
+
const int64_t K = cfg.num_experts_per_tok;
|
| 54 |
+
printf("Dims: D=%ld Q_DIM=%ld KV_DIM=%ld I=%ld E=%ld K=%ld\n", D, Q_DIM, KV_DIM, I, E, K);
|
| 55 |
+
|
| 56 |
+
SafetensorsLoader st;
|
| 57 |
+
if (!st.open(model_dir)) return 1;
|
| 58 |
+
AclRuntime rt;
|
| 59 |
+
rt.init(0);
|
| 60 |
+
|
| 61 |
+
DeviceWeightsLoader dw(st, cfg);
|
| 62 |
+
SharedWeights shared;
|
| 63 |
+
LayerAttnWeights attn;
|
| 64 |
+
LayerMoEWeights moe;
|
| 65 |
+
printf("Loading weights...\n");
|
| 66 |
+
if (!dw.load_shared(shared)) return 1;
|
| 67 |
+
if (!dw.load_attention(0, attn)) return 1;
|
| 68 |
+
if (!dw.load_moe(0, rt.stream(), moe)) return 1;
|
| 69 |
+
rt.sync();
|
| 70 |
+
|
| 71 |
+
// ---- Load 5 prefill tokens ----
|
| 72 |
+
auto tok_raw = read_file(attn_data + "/token_ids.bin");
|
| 73 |
+
int32_t S = *(int32_t*)tok_raw.data();
|
| 74 |
+
std::vector<int32_t> tokens(S);
|
| 75 |
+
std::memcpy(tokens.data(), tok_raw.data() + 4, S * 4);
|
| 76 |
+
printf("S=%d tokens=[", S); for (auto t : tokens) printf("%d,", t); printf("]\n");
|
| 77 |
+
|
| 78 |
+
// ---- Embed ----
|
| 79 |
+
DeviceBuffer tok_dev(S * 4);
|
| 80 |
+
ACL_CHECK(aclrtMemcpy(tok_dev.get(), S * 4, tokens.data(), S * 4, ACL_MEMCPY_HOST_TO_DEVICE));
|
| 81 |
+
auto t_tok = make_contig_tensor(tok_dev.get(), ACL_INT32, {S});
|
| 82 |
+
auto t_embed_w = make_contig_tensor(shared.embed_tokens.get(), ACL_BF16, {cfg.vocab_size, D});
|
| 83 |
+
|
| 84 |
+
DeviceBuffer x_dev(S * D * 2); // residual / input to layer
|
| 85 |
+
auto t_x = make_contig_tensor(x_dev.get(), ACL_BF16, {S, D});
|
| 86 |
+
index_select(rt.stream(), t_embed_w.get(), 0, t_tok.get(), t_x.get());
|
| 87 |
+
rt.sync();
|
| 88 |
+
|
| 89 |
+
// ---- Scratch buffers for attention_forward ----
|
| 90 |
+
const int64_t MAX_LEN = 128;
|
| 91 |
+
DeviceBuffer k_cache(MAX_LEN * KV_DIM * 2), v_cache(MAX_LEN * KV_DIM * 2);
|
| 92 |
+
DeviceBuffer q_sc(S * Q_DIM * 2), k_sc(S * KV_DIM * 2), v_sc(S * KV_DIM * 2);
|
| 93 |
+
DeviceBuffer xn_sc(S * D * 2), rstd_sc(S * std::max(Hq, Hkv) * 4);
|
| 94 |
+
DeviceBuffer rope_sc(1 * S * Hq * Dh * 2);
|
| 95 |
+
DeviceBuffer attn_fias_sc(S * Q_DIM * 2); // FIAS output buffer (before o_proj)
|
| 96 |
+
DeviceBuffer attn_out_dev(S * D * 2);
|
| 97 |
+
|
| 98 |
+
// ---- Causal mask (2048x2048) for prefill ----
|
| 99 |
+
const int64_t MASK = 2048;
|
| 100 |
+
DeviceBuffer mask_dev(MASK * MASK);
|
| 101 |
+
std::vector<uint8_t> mh(MASK * MASK, 0);
|
| 102 |
+
for (int i = 0; i < MASK; i++)
|
| 103 |
+
for (int j = i+1; j < MASK; j++) mh[i*MASK + j] = 1;
|
| 104 |
+
ACL_CHECK(aclrtMemcpy(mask_dev.get(), MASK*MASK, mh.data(), MASK*MASK, ACL_MEMCPY_HOST_TO_DEVICE));
|
| 105 |
+
auto t_mask = make_contig_tensor(mask_dev.get(), ACL_BOOL, {1, 1, MASK, MASK});
|
| 106 |
+
|
| 107 |
+
// ---- Attention forward ----
|
| 108 |
+
attention_forward(
|
| 109 |
+
rt.stream(), cfg, attn,
|
| 110 |
+
x_dev.get(), S,
|
| 111 |
+
/*past_len=*/0, k_cache.get(), v_cache.get(), MAX_LEN,
|
| 112 |
+
t_mask.get(),
|
| 113 |
+
q_sc.get(), k_sc.get(), v_sc.get(),
|
| 114 |
+
xn_sc.get(), rstd_sc.get(), rope_sc.get(),
|
| 115 |
+
attn_fias_sc.get(),
|
| 116 |
+
attn_out_dev.get());
|
| 117 |
+
rt.sync();
|
| 118 |
+
|
| 119 |
+
// ---- x1 = x + attn_out (residual) — should match attn_data/final_out.bin ----
|
| 120 |
+
DeviceBuffer x1_dev(S * D * 2);
|
| 121 |
+
auto t_attn_out = make_contig_tensor(attn_out_dev.get(), ACL_BF16, {S, D});
|
| 122 |
+
auto t_x1 = make_contig_tensor(x1_dev.get(), ACL_BF16, {S, D});
|
| 123 |
+
bf16_add(rt.stream(), t_x.get(), t_attn_out.get(), t_x1.get());
|
| 124 |
+
rt.sync();
|
| 125 |
+
|
| 126 |
+
auto attn_ref_h = read_file(attn_data + "/final_out.bin");
|
| 127 |
+
std::vector<uint16_t> x1_host(S * D);
|
| 128 |
+
ACL_CHECK(aclrtMemcpy(x1_host.data(), S*D*2, x1_dev.get(), S*D*2, ACL_MEMCPY_DEVICE_TO_HOST));
|
| 129 |
+
auto* ar = (const uint16_t*)attn_ref_h.data();
|
| 130 |
+
double al2d=0, al2r=0, amaxd=0;
|
| 131 |
+
for (int i = 0; i < S*D; i++) {
|
| 132 |
+
float a = bf16_to_float(x1_host[i]), b = bf16_to_float(ar[i]);
|
| 133 |
+
al2d += (a-b)*(a-b); al2r += b*b;
|
| 134 |
+
if (std::abs(a-b) > amaxd) amaxd = std::abs(a-b);
|
| 135 |
+
}
|
| 136 |
+
double arel = std::sqrt(al2d) / (std::sqrt(al2r) + 1e-10);
|
| 137 |
+
printf(" [attn] x + attn_out vs attn_data/final_out.bin: rel=%.4e max=%.4f\n", arel, amaxd);
|
| 138 |
+
|
| 139 |
+
// ---- MoE scratch buffers ----
|
| 140 |
+
const int64_t TOTAL = S * K;
|
| 141 |
+
DeviceBuffer moe_xn(S * D * 2), moe_rstd(S * 4);
|
| 142 |
+
DeviceBuffer moe_logits(S * E * 2);
|
| 143 |
+
DeviceBuffer moe_topk_w(S * K * 2), moe_topk_idx(S * K * 4), moe_row_idx(S * K * 4);
|
| 144 |
+
DeviceBuffer moe_ex_x(TOTAL * D * 2), moe_ex_ri(TOTAL * 4), moe_tpe(E * 8);
|
| 145 |
+
DeviceBuffer moe_fwd(TOTAL * 8);
|
| 146 |
+
DeviceBuffer moe_gate(TOTAL * I * 2), moe_up(TOTAL * I * 2), moe_down(TOTAL * D * 2);
|
| 147 |
+
DeviceBuffer moe_packed(TOTAL * D * 2), moe_weighted(S * K * D * 2);
|
| 148 |
+
DeviceBuffer moe_out_dev(S * D * 2);
|
| 149 |
+
|
| 150 |
+
moe_forward(rt.stream(), cfg, attn, moe,
|
| 151 |
+
x1_dev.get(), S,
|
| 152 |
+
moe_xn.get(), moe_rstd.get(),
|
| 153 |
+
moe_logits.get(),
|
| 154 |
+
moe_topk_w.get(), moe_topk_idx.get(), moe_row_idx.get(),
|
| 155 |
+
moe_ex_x.get(), moe_ex_ri.get(), moe_tpe.get(),
|
| 156 |
+
moe_fwd.get(),
|
| 157 |
+
moe_gate.get(), moe_up.get(), moe_down.get(),
|
| 158 |
+
moe_packed.get(), moe_weighted.get(),
|
| 159 |
+
moe_out_dev.get());
|
| 160 |
+
rt.sync();
|
| 161 |
+
|
| 162 |
+
// ---- x2 = x1 + moe_out (residual) — should match moe_data/final_out.bin ----
|
| 163 |
+
DeviceBuffer x2_dev(S * D * 2);
|
| 164 |
+
auto t_moe_out = make_contig_tensor(moe_out_dev.get(), ACL_BF16, {S, D});
|
| 165 |
+
auto t_x2 = make_contig_tensor(x2_dev.get(), ACL_BF16, {S, D});
|
| 166 |
+
bf16_add(rt.stream(), t_x1.get(), t_moe_out.get(), t_x2.get());
|
| 167 |
+
rt.sync();
|
| 168 |
+
|
| 169 |
+
auto moe_ref_h = read_file(moe_data + "/final_out.bin");
|
| 170 |
+
std::vector<uint16_t> x2_host(S * D);
|
| 171 |
+
ACL_CHECK(aclrtMemcpy(x2_host.data(), S*D*2, x2_dev.get(), S*D*2, ACL_MEMCPY_DEVICE_TO_HOST));
|
| 172 |
+
auto* mr = (const uint16_t*)moe_ref_h.data();
|
| 173 |
+
double ml2d=0, ml2r=0, mmaxd=0;
|
| 174 |
+
for (int i = 0; i < S*D; i++) {
|
| 175 |
+
float a = bf16_to_float(x2_host[i]), b = bf16_to_float(mr[i]);
|
| 176 |
+
ml2d += (a-b)*(a-b); ml2r += b*b;
|
| 177 |
+
if (std::abs(a-b) > mmaxd) mmaxd = std::abs(a-b);
|
| 178 |
+
}
|
| 179 |
+
double mrel = std::sqrt(ml2d) / (std::sqrt(ml2r) + 1e-10);
|
| 180 |
+
printf(" [full] x1 + moe_out vs moe_data/final_out.bin: rel=%.4e max=%.4f\n", mrel, mmaxd);
|
| 181 |
+
printf(" x2[0, :4]: %.5f %.5f %.5f %.5f\n",
|
| 182 |
+
bf16_to_float(x2_host[0]), bf16_to_float(x2_host[1]), bf16_to_float(x2_host[2]), bf16_to_float(x2_host[3]));
|
| 183 |
+
printf(" ref[0, :4]: %.5f %.5f %.5f %.5f\n",
|
| 184 |
+
bf16_to_float(mr[0]), bf16_to_float(mr[1]), bf16_to_float(mr[2]), bf16_to_float(mr[3]));
|
| 185 |
+
|
| 186 |
+
// Tolerance: attn chain 5e-3 (tight, only linear ops); full layer 1e-1 (MoE's discrete topk
|
| 187 |
+
// routing amplifies BF16 noise — tiny input changes flip expert selection, magnifying output
|
| 188 |
+
// delta. End-to-end CLI correctness is validated by test_chat_flow.sh separately.)
|
| 189 |
+
bool pass = (arel < 5e-3) && (mrel < 1e-1);
|
| 190 |
+
printf("\n%s\n", pass ? "=== test_layer_forward PASS ===" : "=== test_layer_forward FAIL ===");
|
| 191 |
+
return pass ? 0 : 1;
|
| 192 |
+
}
|
|
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// test_linear_hf.cpp — verify linear_hf (y = x @ W.T with HF [out, in] layout).
|
| 2 |
+
#include "acl_common.h"
|
| 3 |
+
#include "acl_runtime.h"
|
| 4 |
+
#include "aclnn_ops.h"
|
| 5 |
+
|
| 6 |
+
#include <cmath>
|
| 7 |
+
#include <cstdio>
|
| 8 |
+
#include <cstring>
|
| 9 |
+
#include <fstream>
|
| 10 |
+
#include <vector>
|
| 11 |
+
|
| 12 |
+
// Widen a bfloat16 bit pattern to a full IEEE-754 binary32 value.
// bf16 is exactly the top 16 bits of the float encoding, so shifting
// left by 16 and reinterpreting yields the represented value.
static float bf16_to_float(uint16_t x) {
    const uint32_t widened = static_cast<uint32_t>(x) << 16;
    float result;
    std::memcpy(&result, &widened, sizeof(result));
    return result;
}
|
| 15 |
+
|
| 16 |
+
int main() {
|
| 17 |
+
const std::string data = "tests/mm_data";
|
| 18 |
+
int64_t N = 0, D = 0, OUT = 0;
|
| 19 |
+
{
|
| 20 |
+
std::ifstream f(data + "/shape.txt"); std::string line;
|
| 21 |
+
while (std::getline(f, line)) {
|
| 22 |
+
auto eq = line.find('='); if (eq == std::string::npos) continue;
|
| 23 |
+
auto k = line.substr(0, eq); auto v = std::atoll(line.c_str() + eq + 1);
|
| 24 |
+
if (k == "N") N = v; else if (k == "D") D = v; else if (k == "OUT") OUT = v;
|
| 25 |
+
}
|
| 26 |
+
}
|
| 27 |
+
printf("N=%ld D=%ld OUT=%ld\n", N, D, OUT);
|
| 28 |
+
|
| 29 |
+
AclRuntime rt;
|
| 30 |
+
rt.init(0);
|
| 31 |
+
|
| 32 |
+
auto read_all = [&](const std::string& p) {
|
| 33 |
+
std::ifstream f(p, std::ios::binary | std::ios::ate); size_t sz = f.tellg();
|
| 34 |
+
f.seekg(0); std::vector<uint8_t> v(sz); f.read((char*)v.data(), sz); return v;
|
| 35 |
+
};
|
| 36 |
+
auto x_h = read_all(data + "/x.bin");
|
| 37 |
+
auto W_h = read_all(data + "/W.bin");
|
| 38 |
+
auto yr_h = read_all(data + "/y_ref.bin");
|
| 39 |
+
|
| 40 |
+
DeviceBuffer x_d(N * D * 2);
|
| 41 |
+
DeviceBuffer W_d(OUT * D * 2);
|
| 42 |
+
DeviceBuffer y_d(N * OUT * 2);
|
| 43 |
+
ACL_CHECK(aclrtMemcpy(x_d.get(), x_h.size(), x_h.data(), x_h.size(), ACL_MEMCPY_HOST_TO_DEVICE));
|
| 44 |
+
ACL_CHECK(aclrtMemcpy(W_d.get(), W_h.size(), W_h.data(), W_h.size(), ACL_MEMCPY_HOST_TO_DEVICE));
|
| 45 |
+
|
| 46 |
+
auto t_x = make_contig_tensor(x_d.get(), ACL_BF16, {N, D});
|
| 47 |
+
auto t_y = make_contig_tensor(y_d.get(), ACL_BF16, {N, OUT});
|
| 48 |
+
|
| 49 |
+
linear_hf(rt.stream(), t_x.get(), W_d.get(), ACL_BF16, OUT, D, t_y.get());
|
| 50 |
+
rt.sync();
|
| 51 |
+
|
| 52 |
+
std::vector<uint16_t> y_cxx(N * OUT);
|
| 53 |
+
ACL_CHECK(aclrtMemcpy(y_cxx.data(), N * OUT * 2, y_d.get(), N * OUT * 2, ACL_MEMCPY_DEVICE_TO_HOST));
|
| 54 |
+
auto* y_ref = (const uint16_t*)yr_h.data();
|
| 55 |
+
|
| 56 |
+
double l2d = 0, l2r = 0, maxd = 0;
|
| 57 |
+
for (int i = 0; i < N * OUT; i++) {
|
| 58 |
+
float a = bf16_to_float(y_cxx[i]);
|
| 59 |
+
float b = bf16_to_float(y_ref[i]);
|
| 60 |
+
l2d += (a-b)*(a-b); l2r += b*b;
|
| 61 |
+
if (std::abs(a-b) > maxd) maxd = std::abs(a-b);
|
| 62 |
+
}
|
| 63 |
+
double rel = std::sqrt(l2d) / (std::sqrt(l2r) + 1e-10);
|
| 64 |
+
printf("L2 diff=%.4f ref=%.4f relative=%.4e max_abs=%.4f\n",
|
| 65 |
+
std::sqrt(l2d), std::sqrt(l2r), rel, maxd);
|
| 66 |
+
printf("y_cxx[0..3]: "); for (int i = 0; i < 4; i++) printf("%.3f ", bf16_to_float(y_cxx[i])); printf("\n");
|
| 67 |
+
printf("y_ref[0..3]: "); for (int i = 0; i < 4; i++) printf("%.3f ", bf16_to_float(y_ref[i])); printf("\n");
|
| 68 |
+
|
| 69 |
+
// BF16 matmul has more precision loss than RmsNorm. Allow 1% relative error.
|
| 70 |
+
bool ok = rel < 1e-2;
|
| 71 |
+
printf("\n%s\n", ok ? "=== test_linear_hf PASS ===" : "=== test_linear_hf FAIL ===");
|
| 72 |
+
return ok ? 0 : 1;
|
| 73 |
+
}
|
|
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// test_model_config.cpp — load config.json, derive TP shard sizes, verify all expected
|
| 2 |
+
// HF tensors exist in safetensors for Qwen3-235B.
|
| 3 |
+
#include "model_config.h"
|
| 4 |
+
#include "safetensors_loader.h"
|
| 5 |
+
|
| 6 |
+
#include <cstdio>
|
| 7 |
+
#include <sstream>
|
| 8 |
+
#include <string>
|
| 9 |
+
|
| 10 |
+
// Usage: test_model_config [model_dir] [tp_size] [tp_rank]
// Loads config.json, derives TP shard sizes, and verifies every expected HF
// tensor name/shape exists in the safetensors index. Exits 0 only when no
// tensor is missing or mis-shaped.
int main(int argc, char** argv) {
    std::string dir = argc > 1 ? argv[1]
        : "/path/to/Qwen3-235B-A22B-Instruct-2507-BF16";
    int tp_size = argc > 2 ? std::atoi(argv[2]) : 16;
    int tp_rank = argc > 3 ? std::atoi(argv[3]) : 0;

    // Parse the HF config and compute per-rank derived sizes
    // (q_dim_per_rank, kv_dim_per_rank, i_per_rank, ...).
    ModelConfig cfg;
    if (!cfg.load_from_json(dir + "/config.json")) return 1;
    cfg.compute_derived(tp_size, tp_rank);
    printf("%s\n", cfg.describe().c_str());

    // Opens the safetensors shard index (metadata only, no tensor data read).
    SafetensorsLoader loader;
    if (!loader.open(dir)) return 1;

    // Verify all expected tensor names & shapes match cfg.
    int missing = 0, shape_mismatch = 0;
    // Look up `name` in the index; count (and print) either a missing tensor
    // or a stored shape that differs from `expected`.
    auto check_shape = [&](const std::string& name, const std::vector<int64_t>& expected) {
        auto* m = loader.get(name);
        if (!m) {
            printf(" MISSING: %s\n", name.c_str());
            missing++;
            return;
        }
        if (m->shape != expected) {
            printf(" SHAPE MISMATCH: %s got=[", name.c_str());
            for (size_t i = 0; i < m->shape.size(); i++) printf("%s%ld", i ? "," : "", m->shape[i]);
            printf("] want=[");
            for (size_t i = 0; i < expected.size(); i++) printf("%s%ld", i ? "," : "", expected[i]);
            printf("]\n");
            shape_mismatch++;
        }
    };

    // embed/head (untied in this checkpoint: both full [vocab, hidden])
    check_shape("model.embed_tokens.weight", {cfg.vocab_size, cfg.hidden_size});
    check_shape("lm_head.weight", {cfg.vocab_size, cfg.hidden_size});
    check_shape("model.norm.weight", {cfg.hidden_size});

    // Attention weights (HF stores as [out, in])
    int64_t q_full = cfg.num_attention_heads * cfg.head_dim;
    int64_t kv_full = cfg.num_key_value_heads * cfg.head_dim;
    for (int L = 0; L < cfg.num_hidden_layers; L++) {
        auto base = "model.layers." + std::to_string(L);
        check_shape(base + ".input_layernorm.weight", {cfg.hidden_size});
        check_shape(base + ".post_attention_layernorm.weight", {cfg.hidden_size});
        check_shape(base + ".self_attn.q_proj.weight", {q_full, cfg.hidden_size});
        check_shape(base + ".self_attn.k_proj.weight", {kv_full, cfg.hidden_size});
        check_shape(base + ".self_attn.v_proj.weight", {kv_full, cfg.hidden_size});
        check_shape(base + ".self_attn.o_proj.weight", {cfg.hidden_size, q_full});
        // Qwen3 uses q_norm / k_norm (norm per head) — check existence
        check_shape(base + ".self_attn.q_norm.weight", {cfg.head_dim});
        check_shape(base + ".self_attn.k_norm.weight", {cfg.head_dim});
        // MoE router
        check_shape(base + ".mlp.gate.weight", {cfg.num_experts, cfg.hidden_size});
        // Spot-check few experts (full enumeration is 94*384=36096 lines)
        for (int e : {0, 1, 63, 127}) {
            auto ebase = base + ".mlp.experts." + std::to_string(e);
            check_shape(ebase + ".gate_proj.weight", {cfg.moe_intermediate_size, cfg.hidden_size});
            check_shape(ebase + ".up_proj.weight", {cfg.moe_intermediate_size, cfg.hidden_size});
            check_shape(ebase + ".down_proj.weight", {cfg.hidden_size, cfg.moe_intermediate_size});
        }
    }

    // Print TP memory estimate (the trailing *2 factors are sizeof(BF16)).
    int64_t attn_bytes_per_rank = 0;
    attn_bytes_per_rank += cfg.q_dim_per_rank * cfg.hidden_size * 2; // q_proj
    attn_bytes_per_rank += cfg.kv_dim_per_rank * cfg.hidden_size * 2; // k_proj
    attn_bytes_per_rank += cfg.kv_dim_per_rank * cfg.hidden_size * 2; // v_proj
    attn_bytes_per_rank += cfg.hidden_size * cfg.q_dim_per_rank * 2; // o_proj
    attn_bytes_per_rank *= cfg.num_hidden_layers;

    int64_t moe_bytes_per_rank = 0;
    // gate_exps + up_exps: [E, I_per_rank, D]
    moe_bytes_per_rank += 2 * cfg.num_experts * cfg.i_per_rank * cfg.hidden_size * 2;
    // down_exps: [E, D, I_per_rank]
    moe_bytes_per_rank += cfg.num_experts * cfg.hidden_size * cfg.i_per_rank * 2;
    moe_bytes_per_rank *= cfg.num_hidden_layers;

    int64_t embed_bytes = cfg.vocab_size * cfg.hidden_size * 2 * 2; // embed + lm_head
    int64_t router_bytes = cfg.num_experts * cfg.hidden_size * 2 * cfg.num_hidden_layers;
    // One input + one post-attn norm per layer, plus the final model.norm.
    int64_t norm_bytes = cfg.hidden_size * 2 * (2 * cfg.num_hidden_layers + 1);
    int64_t total_per_rank = attn_bytes_per_rank + moe_bytes_per_rank + embed_bytes + router_bytes + norm_bytes;

    printf("\nPer-rank weight memory estimate (BF16, TP=%d):\n", tp_size);
    printf(" attention: %.2f GB\n", attn_bytes_per_rank / 1e9);
    printf(" MoE exps: %.0f MB\n" != nullptr ? " MoE exps: %.2f GB\n" : "", moe_bytes_per_rank / 1e9);
    printf(" embed+head: %.2f GB (replicated)\n", embed_bytes / 1e9);
    printf(" router: %.2f MB (replicated)\n", router_bytes / 1e6);
    printf(" norms: %.2f MB (replicated)\n", norm_bytes / 1e6);
    printf(" TOTAL: %.2f GB\n", total_per_rank / 1e9);

    int errors = missing + shape_mismatch;
    printf("\nMissing: %d, Shape mismatch: %d\n", missing, shape_mismatch);
    printf("%s\n", errors == 0 ? "=== test_model_config PASS ==="
                               : "=== test_model_config FAIL ===");
    return errors == 0 ? 0 : 1;
}
|
|
@@ -0,0 +1,676 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// test_moe_layer.cpp — Full MoE layer forward (Qwen3-235B layer 0), TP=1.
|
| 2 |
+
//
|
| 3 |
+
// Pipeline:
|
| 4 |
+
// 1. Post-attention RmsNorm (input from attn_data/final_out.bin)
|
| 5 |
+
// 2. Router: xn @ W_router.T → logits [S, E]
|
| 6 |
+
// 3. TopK softmax → weights [S, K], expert_ids [S, K]
|
| 7 |
+
// 4. Host-normalize top_k weights (Qwen3 norm_topk_prob)
|
| 8 |
+
// 5. MoeInitRoutingV3 → expanded_x [S*K, D], expanded_row_idx, tokens_per_expert
|
| 9 |
+
// 6. GMM gate: expanded_x × gate_exps → [S*K, I]
|
| 10 |
+
// 7. GMM up: same → [S*K, I]
|
| 11 |
+
// 8. silu(gate) * up → [S*K, I]
|
| 12 |
+
// 9. GMM down: act × down_exps → [S*K, D]
|
| 13 |
+
// 10. MoeFinalizeRouting (weighted sum) → [S, D]
|
| 14 |
+
// 11. + residual
|
| 15 |
+
#include "acl_common.h"
|
| 16 |
+
#include "acl_runtime.h"
|
| 17 |
+
#include "aclnn_ops.h"
|
| 18 |
+
#include "device_weights.h"
|
| 19 |
+
#include "model_config.h"
|
| 20 |
+
#include "safetensors_loader.h"
|
| 21 |
+
|
| 22 |
+
#include <algorithm>
|
| 23 |
+
#include <cmath>
|
| 24 |
+
#include <cstdio>
|
| 25 |
+
#include <cstring>
|
| 26 |
+
#include <fstream>
|
| 27 |
+
#include <tuple>
|
| 28 |
+
#include <vector>
|
| 29 |
+
|
| 30 |
+
// Expand a bfloat16 bit pattern into a float. A bf16 value is the high
// 16 bits of the corresponding binary32 encoding, so the conversion is a
// shift followed by a bitwise reinterpretation.
static float bf16_to_float(uint16_t x) {
    float out;
    const uint32_t as_f32_bits = static_cast<uint32_t>(x) << 16;
    std::memcpy(&out, &as_f32_bits, sizeof(out));
    return out;
}
|
| 33 |
+
// Narrow a float to bfloat16 with round-to-nearest-even on the 16 bits
// being dropped: add 0x7FFF plus the kept word's low bit (tie-break toward
// even), then keep the high half.
static uint16_t float_to_bf16(float x) {
    uint32_t bits;
    std::memcpy(&bits, &x, sizeof(bits));
    const uint32_t tie_break = (bits >> 16) & 1u;
    bits += 0x7FFFu + tie_break;
    return static_cast<uint16_t>(bits >> 16);
}
|
| 37 |
+
// Slurp an entire binary file into memory.
// Returns an empty vector when the file cannot be opened: previously the
// failed tellg() result (-1) was cast to size_t, so a missing fixture
// exploded the vector allocation instead of failing visibly downstream.
static std::vector<uint8_t> read_file(const std::string& p) {
    std::ifstream f(p, std::ios::binary | std::ios::ate);
    if (!f) return {};
    size_t s = (size_t)f.tellg();
    f.seekg(0);
    std::vector<uint8_t> v(s);
    f.read((char*)v.data(), s);
    return v;
}
|
| 41 |
+
|
| 42 |
+
int main() {
|
| 43 |
+
const std::string model_dir = "/path/to/Qwen3-235B-A22B-Instruct-2507-BF16";
|
| 44 |
+
const std::string data_dir = "tests/moe_data";
|
| 45 |
+
|
| 46 |
+
ModelConfig cfg;
|
| 47 |
+
if (!cfg.load_from_json(model_dir + "/config.json")) return 1;
|
| 48 |
+
cfg.compute_derived(1, 0); // TP=1
|
| 49 |
+
const int64_t D = cfg.hidden_size;
|
| 50 |
+
const int64_t I = cfg.moe_intermediate_size;
|
| 51 |
+
const int64_t E = cfg.num_experts;
|
| 52 |
+
const int64_t K = cfg.num_experts_per_tok;
|
| 53 |
+
const double eps = cfg.rms_norm_eps;
|
| 54 |
+
|
| 55 |
+
AclRuntime rt;
|
| 56 |
+
rt.init(0);
|
| 57 |
+
printf("[dbg] rt init ok\n"); fflush(stdout);
|
| 58 |
+
|
| 59 |
+
SafetensorsLoader st;
|
| 60 |
+
if (!st.open(model_dir)) return 1;
|
| 61 |
+
|
| 62 |
+
// ---- Load weights ----
|
| 63 |
+
printf("Loading layer 0 attention weights (for post_attention_layernorm)...\n");
|
| 64 |
+
DeviceWeightsLoader dw(st, cfg);
|
| 65 |
+
LayerAttnWeights attn;
|
| 66 |
+
if (!dw.load_attention(0, attn)) return 1;
|
| 67 |
+
|
| 68 |
+
printf("Loading layer 0 MoE weights (128 experts × 3 projections, stacking + permute)...\n"); fflush(stdout);
|
| 69 |
+
LayerMoEWeights moe;
|
| 70 |
+
if (!dw.load_moe(0, rt.stream(), moe)) return 1;
|
| 71 |
+
rt.sync();
|
| 72 |
+
printf("[dbg] moe load ok\n"); fflush(stdout);
|
| 73 |
+
printf(" router %.1f MB gate_exps %.0f MB up_exps %.0f MB down_exps %.0f MB\n",
|
| 74 |
+
moe.router.size / 1e6, moe.gate_exps.size / 1e6, moe.up_exps.size / 1e6, moe.down_exps.size / 1e6);
|
| 75 |
+
|
| 76 |
+
// ---- Load input & Python reference ----
|
| 77 |
+
int S = 5;
|
| 78 |
+
auto x_in_host = read_file(data_dir + "/x_in.bin");
|
| 79 |
+
auto ref_out_host = read_file(data_dir + "/final_out.bin");
|
| 80 |
+
DeviceBuffer x_dev(S * D * 2);
|
| 81 |
+
ACL_CHECK(aclrtMemcpy(x_dev.get(), x_in_host.size(), x_in_host.data(), x_in_host.size(), ACL_MEMCPY_HOST_TO_DEVICE));
|
| 82 |
+
|
| 83 |
+
// Residual snapshot
|
| 84 |
+
DeviceBuffer residual_dev(S * D * 2);
|
| 85 |
+
ACL_CHECK(aclrtMemcpy(residual_dev.get(), S*D*2, x_dev.get(), S*D*2, ACL_MEMCPY_DEVICE_TO_DEVICE));
|
| 86 |
+
|
| 87 |
+
printf("[dbg] loaded data and residual ok, TOTAL=%ld\n", S * K); fflush(stdout);
|
| 88 |
+
|
| 89 |
+
// ---- Step 1: Post-attention RmsNorm ----
|
| 90 |
+
DeviceBuffer xn_dev(S * D * 2);
|
| 91 |
+
DeviceBuffer rstd_dev(S * 4);
|
| 92 |
+
auto t_x = make_contig_tensor(x_dev.get(), ACL_BF16, {S, D});
|
| 93 |
+
auto t_xn = make_contig_tensor(xn_dev.get(), ACL_BF16, {S, D});
|
| 94 |
+
auto t_ln = make_contig_tensor(attn.post_attention_layernorm.get(), ACL_BF16, {D});
|
| 95 |
+
auto t_rstd = make_contig_tensor(rstd_dev.get(), ACL_FLOAT, {S});
|
| 96 |
+
rms_norm(rt.stream(), t_x.get(), t_ln.get(), eps, t_xn.get(), t_rstd.get());
|
| 97 |
+
rt.sync();
|
| 98 |
+
printf("[dbg] rms_norm ok\n"); fflush(stdout);
|
| 99 |
+
|
| 100 |
+
// ---- Step 2: Router (gate matmul) ----
|
| 101 |
+
DeviceBuffer logits_dev(S * E * 2);
|
| 102 |
+
auto t_logits = make_contig_tensor(logits_dev.get(), ACL_BF16, {S, E});
|
| 103 |
+
// router is [E, D] (HF). logits = xn @ router.T
|
| 104 |
+
linear_hf(rt.stream(), t_xn.get(), moe.router.get(), ACL_BF16, E, D, t_logits.get());
|
| 105 |
+
rt.sync();
|
| 106 |
+
printf("[dbg] router linear ok\n"); fflush(stdout);
|
| 107 |
+
|
| 108 |
+
// ---- Step 3: TopK softmax ----
|
| 109 |
+
DeviceBuffer topk_w_dev(S * K * 2); // BF16
|
| 110 |
+
DeviceBuffer topk_idx_dev(S * K * 4); // int32
|
| 111 |
+
DeviceBuffer row_idx_dev(S * K * 4); // int32 (from gating op, unused for our routing)
|
| 112 |
+
auto t_topk_w = make_contig_tensor(topk_w_dev.get(), ACL_BF16, {S, K});
|
| 113 |
+
auto t_topk_idx = make_contig_tensor(topk_idx_dev.get(), ACL_INT32, {S, K});
|
| 114 |
+
auto t_row_idx = make_contig_tensor(row_idx_dev.get(), ACL_INT32, {S, K});
|
| 115 |
+
moe_gating_topk_softmax(rt.stream(), t_logits.get(), K, t_topk_w.get(), t_topk_idx.get(), t_row_idx.get());
|
| 116 |
+
rt.sync();
|
| 117 |
+
printf("[dbg] topk_softmax ok\n"); fflush(stdout);
|
| 118 |
+
|
| 119 |
+
// ---- Step 4: Host-normalize top_k weights (norm_topk_prob=true) ----
|
| 120 |
+
std::vector<uint16_t> tw_bf(S * K);
|
| 121 |
+
ACL_CHECK(aclrtMemcpy(tw_bf.data(), S*K*2, topk_w_dev.get(), S*K*2, ACL_MEMCPY_DEVICE_TO_HOST));
|
| 122 |
+
for (int s = 0; s < S; s++) {
|
| 123 |
+
float sum = 0.0f;
|
| 124 |
+
for (int k = 0; k < K; k++) sum += bf16_to_float(tw_bf[s*K + k]);
|
| 125 |
+
sum += 1e-20f;
|
| 126 |
+
for (int k = 0; k < K; k++) {
|
| 127 |
+
float v = bf16_to_float(tw_bf[s*K + k]) / sum;
|
| 128 |
+
tw_bf[s*K + k] = float_to_bf16(v);
|
| 129 |
+
}
|
| 130 |
+
}
|
| 131 |
+
ACL_CHECK(aclrtMemcpy(topk_w_dev.get(), S*K*2, tw_bf.data(), S*K*2, ACL_MEMCPY_HOST_TO_DEVICE));
|
| 132 |
+
|
| 133 |
+
// ---- Step 5: MoE init routing ----
|
| 134 |
+
int64_t TOTAL = S * K;
|
| 135 |
+
DeviceBuffer expanded_x_dev(TOTAL * D * 2);
|
| 136 |
+
DeviceBuffer expanded_row_idx_dev(TOTAL * 4);
|
| 137 |
+
DeviceBuffer tokens_per_expert_dev(E * 8);
|
| 138 |
+
|
| 139 |
+
auto t_ex_x = make_contig_tensor(expanded_x_dev.get(), ACL_BF16, {TOTAL, D});
|
| 140 |
+
auto t_ex_ri = make_contig_tensor(expanded_row_idx_dev.get(), ACL_INT32, {TOTAL});
|
| 141 |
+
auto t_tpe = make_contig_tensor(tokens_per_expert_dev.get(), ACL_INT64, {E});
|
| 142 |
+
|
| 143 |
+
moe_init_routing_v3(rt.stream(),
|
| 144 |
+
t_xn.get(), t_topk_idx.get(),
|
| 145 |
+
E, TOTAL,
|
| 146 |
+
t_ex_x.get(), t_ex_ri.get(), t_tpe.get());
|
| 147 |
+
rt.sync();
|
| 148 |
+
printf("[dbg] moe_init_routing ok\n"); fflush(stdout);
|
| 149 |
+
|
| 150 |
+
// Convert tokens_per_expert from counts to cumsum (on host) for GMM groupListType=0.
|
| 151 |
+
DeviceBuffer tpe_cumsum_dev(E * 8);
|
| 152 |
+
{
|
| 153 |
+
std::vector<int64_t> h_counts(E), h_cum(E);
|
| 154 |
+
ACL_CHECK(aclrtMemcpy(h_counts.data(), E*8, tokens_per_expert_dev.get(), E*8, ACL_MEMCPY_DEVICE_TO_HOST));
|
| 155 |
+
int64_t acc = 0;
|
| 156 |
+
for (int i = 0; i < E; i++) { acc += h_counts[i]; h_cum[i] = acc; }
|
| 157 |
+
ACL_CHECK(aclrtMemcpy(tpe_cumsum_dev.get(), E*8, h_cum.data(), E*8, ACL_MEMCPY_HOST_TO_DEVICE));
|
| 158 |
+
}
|
| 159 |
+
auto t_tpe_cum = make_contig_tensor(tpe_cumsum_dev.get(), ACL_INT64, {E});
|
| 160 |
+
|
| 161 |
+
// ---- Step 6/7: GMM gate and up ----
|
| 162 |
+
DeviceBuffer gate_out_dev(TOTAL * I * 2);
|
| 163 |
+
DeviceBuffer up_out_dev(TOTAL * I * 2);
|
| 164 |
+
auto t_gate_out = make_contig_tensor(gate_out_dev.get(), ACL_BF16, {TOTAL, I});
|
| 165 |
+
auto t_up_out = make_contig_tensor(up_out_dev.get(), ACL_BF16, {TOTAL, I});
|
| 166 |
+
// gate/up_exps loaded as [E, D, I] row-major
|
| 167 |
+
auto t_w_gate = make_contig_tensor(moe.gate_exps.get(), ACL_BF16, {E, D, I});
|
| 168 |
+
auto t_w_up = make_contig_tensor(moe.up_exps.get(), ACL_BF16, {E, D, I});
|
| 169 |
+
// Use cumsum group_list (groupListType=0): empirically more reliable with many zero-count experts.
|
| 170 |
+
grouped_matmul_v4(rt.stream(), t_ex_x.get(), t_w_gate.get(), t_tpe_cum.get(), t_gate_out.get(), 0);
|
| 171 |
+
rt.sync();
|
| 172 |
+
printf("[dbg] gmm gate ok\n"); fflush(stdout);
|
| 173 |
+
grouped_matmul_v4(rt.stream(), t_ex_x.get(), t_w_up.get(), t_tpe_cum.get(), t_up_out.get(), 0);
|
| 174 |
+
rt.sync();
|
| 175 |
+
printf("[dbg] gmm up ok\n"); fflush(stdout);
|
| 176 |
+
|
| 177 |
+
// ---- Step 8: SwiGLU ----
|
| 178 |
+
// act = silu(gate) * up (inplace on gate_out)
|
| 179 |
+
silu(rt.stream(), t_gate_out.get(), t_gate_out.get());
|
| 180 |
+
rt.sync(); printf("[dbg] silu ok\n"); fflush(stdout);
|
| 181 |
+
mul(rt.stream(), t_gate_out.get(), t_up_out.get(), t_gate_out.get());
|
| 182 |
+
rt.sync(); printf("[dbg] mul ok\n"); fflush(stdout);
|
| 183 |
+
// now gate_out_dev contains the activated intermediate
|
| 184 |
+
|
| 185 |
+
// ---- Step 9: GMM down ----
|
| 186 |
+
DeviceBuffer down_out_dev(TOTAL * D * 2);
|
| 187 |
+
auto t_down_out = make_contig_tensor(down_out_dev.get(), ACL_BF16, {TOTAL, D});
|
| 188 |
+
auto t_w_down = make_contig_tensor(moe.down_exps.get(), ACL_BF16, {E, I, D});
|
| 189 |
+
grouped_matmul_v4(rt.stream(), t_gate_out.get(), t_w_down.get(), t_tpe_cum.get(), t_down_out.get(), 0);
|
| 190 |
+
rt.sync();
|
| 191 |
+
printf("[dbg] gmm down ok\n"); fflush(stdout);
|
| 192 |
+
|
| 193 |
+
// ---- Step 10: Device-side manual finalize (replacement for buggy MoeFinalizeRoutingV2) ----
|
| 194 |
+
// Compute forward permutation fwd[n*K + k] = p where token n's k-th expert's output is at
|
| 195 |
+
// expanded position p. We use tokens_per_expert (cumsum) + topk_idx to resolve this correctly,
|
| 196 |
+
// regardless of the exact rowIdxType semantics returned by MoeInitRoutingV3.
|
| 197 |
+
DeviceBuffer fwd_dev(TOTAL * 8);
|
| 198 |
+
{
|
| 199 |
+
std::vector<int64_t> h_tpe2(E);
|
| 200 |
+
std::vector<int32_t> h_tidx3(S * K);
|
| 201 |
+
ACL_CHECK(aclrtMemcpy(h_tpe2.data(), E*8, tokens_per_expert_dev.get(), E*8, ACL_MEMCPY_DEVICE_TO_HOST));
|
| 202 |
+
ACL_CHECK(aclrtMemcpy(h_tidx3.data(), S*K*4, topk_idx_dev.get(), S*K*4, ACL_MEMCPY_DEVICE_TO_HOST));
|
| 203 |
+
|
| 204 |
+
// Sort (n, k) pairs by expert ascending (stable). For each expert in order, tokens
|
| 205 |
+
// appear in ascending token index (since MoeInitRoutingV3 is stable by s).
|
| 206 |
+
// Specifically: expanded positions 0..tpe[0]-1 are for expert 0 (tokens picking e=0, in n-ascending order),
|
| 207 |
+
// next tpe[1] are for expert 1, etc.
|
| 208 |
+
//
|
| 209 |
+
// To build fwd: for each (n, k), expert e = topk_idx[n, k]. Position p is the base of expert e's
|
| 210 |
+
// block plus the rank of n within tokens picking e.
|
| 211 |
+
std::vector<int64_t> expert_base(E + 1, 0);
|
| 212 |
+
for (int e = 0; e < E; e++) expert_base[e + 1] = expert_base[e] + h_tpe2[e];
|
| 213 |
+
|
| 214 |
+
std::vector<int> expert_slot(E, 0); // next available slot per expert
|
| 215 |
+
std::vector<int64_t> fwd(TOTAL);
|
| 216 |
+
// Iterate in token-ascending, k-ascending order — match MoeInitRoutingV3's stable sort convention.
|
| 217 |
+
// For each (n, k) sorted by (expert[n,k], n), assign p.
|
| 218 |
+
// Simpler: pre-collect (e, n, k) triples, sort by (e, n), then p is the rank.
|
| 219 |
+
std::vector<std::tuple<int, int, int>> triples;
|
| 220 |
+
triples.reserve(TOTAL);
|
| 221 |
+
for (int n = 0; n < S; n++) for (int k = 0; k < K; k++) {
|
| 222 |
+
triples.emplace_back(h_tidx3[n * K + k], n, k);
|
| 223 |
+
}
|
| 224 |
+
std::sort(triples.begin(), triples.end(), [](const auto& a, const auto& b){
|
| 225 |
+
if (std::get<0>(a) != std::get<0>(b)) return std::get<0>(a) < std::get<0>(b);
|
| 226 |
+
return std::get<1>(a) < std::get<1>(b);
|
| 227 |
+
});
|
| 228 |
+
for (int64_t p = 0; p < TOTAL; p++) {
|
| 229 |
+
auto [e, n, k] = triples[p];
|
| 230 |
+
fwd[n * K + k] = p;
|
| 231 |
+
}
|
| 232 |
+
ACL_CHECK(aclrtMemcpy(fwd_dev.get(), TOTAL*8, fwd.data(), TOTAL*8, ACL_MEMCPY_HOST_TO_DEVICE));
|
| 233 |
+
}
|
| 234 |
+
auto t_fwd = make_contig_tensor(fwd_dev.get(), ACL_INT64, {TOTAL});
|
| 235 |
+
|
| 236 |
+
// Gather: packed [S*K, D] = down_out[fwd, :]
|
| 237 |
+
DeviceBuffer packed_dev(TOTAL * D * 2);
|
| 238 |
+
auto t_packed = make_contig_tensor(packed_dev.get(), ACL_BF16, {TOTAL, D});
|
| 239 |
+
index_select(rt.stream(), t_down_out.get(), 0, t_fwd.get(), t_packed.get());
|
| 240 |
+
rt.sync();
|
| 241 |
+
|
| 242 |
+
// Broadcast-multiply by topk_w: view packed as [S, K, D], topk_w as [S, K, 1].
|
| 243 |
+
auto t_packed_3d = make_contig_tensor(packed_dev.get(), ACL_BF16, {S, K, D});
|
| 244 |
+
auto t_topk_w_3d = make_contig_tensor(topk_w_dev.get(), ACL_BF16, {S, K, 1});
|
| 245 |
+
DeviceBuffer weighted_dev(S * K * D * 2);
|
| 246 |
+
auto t_weighted = make_contig_tensor(weighted_dev.get(), ACL_BF16, {S, K, D});
|
| 247 |
+
mul(rt.stream(), t_packed_3d.get(), t_topk_w_3d.get(), t_weighted.get());
|
| 248 |
+
rt.sync();
|
| 249 |
+
|
| 250 |
+
// Verify broadcast mul + sum by dumping all k entries and summing on host.
|
| 251 |
+
{
|
| 252 |
+
std::vector<uint16_t> h_pk_all(S * K * D);
|
| 253 |
+
std::vector<uint16_t> h_wt_all(S * K * D);
|
| 254 |
+
std::vector<uint16_t> h_tw_all(S * K);
|
| 255 |
+
ACL_CHECK(aclrtMemcpy(h_pk_all.data(), S*K*D*2, packed_dev.get(), S*K*D*2, ACL_MEMCPY_DEVICE_TO_HOST));
|
| 256 |
+
ACL_CHECK(aclrtMemcpy(h_wt_all.data(), S*K*D*2, weighted_dev.get(), S*K*D*2, ACL_MEMCPY_DEVICE_TO_HOST));
|
| 257 |
+
ACL_CHECK(aclrtMemcpy(h_tw_all.data(), S*K*2, topk_w_dev.get(), S*K*2, ACL_MEMCPY_DEVICE_TO_HOST));
|
| 258 |
+
|
| 259 |
+
printf(" verify weighted[0, k, 0] = packed[0, k, 0] * topk_w[0, k] for all k:\n");
|
| 260 |
+
float host_sum = 0;
|
| 261 |
+
for (int k = 0; k < K; k++) {
|
| 262 |
+
float p = bf16_to_float(h_pk_all[k * D]); // packed[0, k, 0] = offset s*K*D + k*D + 0 = k*D (for s=0)
|
| 263 |
+
float w = bf16_to_float(h_tw_all[k]); // topk_w[0, k]
|
| 264 |
+
float wt = bf16_to_float(h_wt_all[k * D]); // weighted[0, k, 0]
|
| 265 |
+
host_sum += p * w;
|
| 266 |
+
printf(" k=%d: packed=%.5f * topk_w=%.5f = expect=%.5f dev=%.5f\n",
|
| 267 |
+
k, p, w, p*w, wt);
|
| 268 |
+
}
|
| 269 |
+
printf(" host_sum_of_weighted[0, :, 0] = %.5f (expected moe_out[0,0] = -0.02466)\n", host_sum);
|
| 270 |
+
}
|
| 271 |
+
|
| 272 |
+
// ReduceSum over K axis → [S, D]
|
| 273 |
+
DeviceBuffer moe_out_dev(S * D * 2);
|
| 274 |
+
auto t_moe_out = make_contig_tensor(moe_out_dev.get(), ACL_BF16, {S, D});
|
| 275 |
+
reduce_sum(rt.stream(), t_weighted.get(), {1}, /*keep_dims=*/false, ACL_BF16, t_moe_out.get());
|
| 276 |
+
rt.sync();
|
| 277 |
+
printf("[dbg] device-side finalize (gather+mul+reduce) ok\n"); fflush(stdout);
|
| 278 |
+
|
| 279 |
+
// Residual add to produce final_out
|
| 280 |
+
float alpha_v = 1.0f; aclScalar* alpha = aclCreateScalar(&alpha_v, ACL_FLOAT);
|
| 281 |
+
DeviceBuffer final_dev(S * D * 2);
|
| 282 |
+
auto t_final = make_contig_tensor(final_dev.get(), ACL_BF16, {S, D});
|
| 283 |
+
auto t_res = make_contig_tensor(residual_dev.get(), ACL_BF16, {S, D});
|
| 284 |
+
{
|
| 285 |
+
uint64_t ws = 0; aclOpExecutor* e = nullptr;
|
| 286 |
+
ACLNN_CHECK(aclnnAddGetWorkspaceSize(t_res.get(), t_moe_out.get(), alpha, t_final.get(), &ws, &e));
|
| 287 |
+
DeviceBuffer wb; if (ws > 0) wb.alloc(ws);
|
| 288 |
+
ACLNN_CHECK(aclnnAdd(wb.get(), ws, e, rt.stream()));
|
| 289 |
+
}
|
| 290 |
+
aclDestroyScalar(alpha);
|
| 291 |
+
rt.sync();
|
| 292 |
+
|
| 293 |
+
// ---- Compare (intermediate + final) ----
|
| 294 |
+
auto compare_bf16 = [&](const char* label, void* dev_ptr, int64_t nelem,
|
| 295 |
+
const std::string& ref_file) {
|
| 296 |
+
std::vector<uint16_t> cxx(nelem);
|
| 297 |
+
ACL_CHECK(aclrtMemcpy(cxx.data(), nelem*2, dev_ptr, nelem*2, ACL_MEMCPY_DEVICE_TO_HOST));
|
| 298 |
+
auto refbuf = read_file(data_dir + "/" + ref_file);
|
| 299 |
+
auto* ref = (const uint16_t*)refbuf.data();
|
| 300 |
+
double l2d = 0, l2r = 0, maxd = 0;
|
| 301 |
+
for (int64_t i = 0; i < nelem; i++) {
|
| 302 |
+
float a = bf16_to_float(cxx[i]), b = bf16_to_float(ref[i]);
|
| 303 |
+
l2d += (a-b)*(a-b); l2r += b*b;
|
| 304 |
+
if (std::abs(a-b) > maxd) maxd = std::abs(a-b);
|
| 305 |
+
}
|
| 306 |
+
double rel = std::sqrt(l2d) / (std::sqrt(l2r) + 1e-10);
|
| 307 |
+
printf(" [cmp] %-12s rel=%.4e max_abs=%.4f cxx[:4]=%.5f %.5f %.5f %.5f ref[:4]=%.5f %.5f %.5f %.5f\n",
|
| 308 |
+
label, rel, maxd,
|
| 309 |
+
bf16_to_float(cxx[0]), bf16_to_float(cxx[1]), bf16_to_float(cxx[2]), bf16_to_float(cxx[3]),
|
| 310 |
+
bf16_to_float(ref[0]), bf16_to_float(ref[1]), bf16_to_float(ref[2]), bf16_to_float(ref[3]));
|
| 311 |
+
return rel;
|
| 312 |
+
};
|
| 313 |
+
|
| 314 |
+
printf("\n=== Intermediate diagnostics ===\n");
|
| 315 |
+
compare_bf16("xn", xn_dev.get(), S * D, "xn.bin");
|
| 316 |
+
compare_bf16("topk_w", topk_w_dev.get(), S * K, "topk_w.bin");
|
| 317 |
+
|
| 318 |
+
// Dump topk_idx (int32) to compare
|
| 319 |
+
{
|
| 320 |
+
std::vector<int32_t> cxx_idx(S*K);
|
| 321 |
+
ACL_CHECK(aclrtMemcpy(cxx_idx.data(), S*K*4, topk_idx_dev.get(), S*K*4, ACL_MEMCPY_DEVICE_TO_HOST));
|
| 322 |
+
auto refbuf = read_file(data_dir + "/topk_idx.bin");
|
| 323 |
+
auto* ref = (const int32_t*)refbuf.data();
|
| 324 |
+
int mismatches = 0;
|
| 325 |
+
for (int i = 0; i < S*K; i++) if (cxx_idx[i] != ref[i]) mismatches++;
|
| 326 |
+
printf(" [cmp] topk_idx mismatches=%d/%d cxx[0,:4]=%d %d %d %d ref[0,:4]=%d %d %d %d\n",
|
| 327 |
+
mismatches, S*K,
|
| 328 |
+
cxx_idx[0], cxx_idx[1], cxx_idx[2], cxx_idx[3],
|
| 329 |
+
ref[0], ref[1], ref[2], ref[3]);
|
| 330 |
+
}
|
| 331 |
+
|
| 332 |
+
printf("\n=== MoE-only (before residual) ===\n");
|
| 333 |
+
compare_bf16("moe_out", moe_out_dev.get(), S * D, "out_flat.bin");
|
| 334 |
+
|
| 335 |
+
// Manual host-side finalize: verify what down_out + expanded_row_idx + topk_w produce.
|
| 336 |
+
{
|
| 337 |
+
std::vector<uint16_t> h_down(TOTAL * D);
|
| 338 |
+
std::vector<int32_t> h_ri(TOTAL);
|
| 339 |
+
std::vector<uint16_t> h_tw(S * K);
|
| 340 |
+
ACL_CHECK(aclrtMemcpy(h_down.data(), TOTAL*D*2, down_out_dev.get(), TOTAL*D*2, ACL_MEMCPY_DEVICE_TO_HOST));
|
| 341 |
+
ACL_CHECK(aclrtMemcpy(h_ri.data(), TOTAL*4, expanded_row_idx_dev.get(), TOTAL*4, ACL_MEMCPY_DEVICE_TO_HOST));
|
| 342 |
+
ACL_CHECK(aclrtMemcpy(h_tw.data(), S*K*2, topk_w_dev.get(), S*K*2, ACL_MEMCPY_DEVICE_TO_HOST));
|
| 343 |
+
|
| 344 |
+
printf(" expanded_row_idx (all %ld):\n ", TOTAL);
|
| 345 |
+
for (int i = 0; i < TOTAL; i++) {
|
| 346 |
+
printf("%d ", h_ri[i]);
|
| 347 |
+
if ((i+1) % 10 == 0) printf("\n ");
|
| 348 |
+
}
|
| 349 |
+
printf("\n");
|
| 350 |
+
// count unique and check bijection
|
| 351 |
+
std::vector<int> count(TOTAL, 0);
|
| 352 |
+
int out_of_range = 0;
|
| 353 |
+
for (int i = 0; i < TOTAL; i++) {
|
| 354 |
+
int v = h_ri[i];
|
| 355 |
+
if (v >= 0 && v < TOTAL) count[v]++;
|
| 356 |
+
else out_of_range++;
|
| 357 |
+
}
|
| 358 |
+
int bijection_ok = (out_of_range == 0);
|
| 359 |
+
for (int i = 0; i < TOTAL && bijection_ok; i++) if (count[i] != 1) bijection_ok = 0;
|
| 360 |
+
printf(" bijection=%s out_of_range=%d\n", bijection_ok ? "YES" : "NO", out_of_range);
|
| 361 |
+
|
| 362 |
+
// Also dump tokens_per_expert (int64) — should sum to TOTAL
|
| 363 |
+
std::vector<int64_t> h_tpe(E);
|
| 364 |
+
ACL_CHECK(aclrtMemcpy(h_tpe.data(), E*8, tokens_per_expert_dev.get(), E*8, ACL_MEMCPY_DEVICE_TO_HOST));
|
| 365 |
+
int64_t tpe_sum = 0, nonzero = 0;
|
| 366 |
+
int64_t tpe_max = 0;
|
| 367 |
+
for (int i = 0; i < E; i++) { tpe_sum += h_tpe[i]; if (h_tpe[i]>0) nonzero++; if (h_tpe[i]>tpe_max) tpe_max=h_tpe[i]; }
|
| 368 |
+
printf(" tokens_per_expert: sum=%ld nonzero=%ld max=%ld (expected sum=%ld if counts, or last=%ld if cumsum)\n",
|
| 369 |
+
tpe_sum, nonzero, tpe_max, TOTAL, TOTAL);
|
| 370 |
+
printf(" tpe[last 4]: %ld %ld %ld %ld\n", h_tpe[E-4], h_tpe[E-3], h_tpe[E-2], h_tpe[E-1]);
|
| 371 |
+
|
| 372 |
+
std::vector<float> manual(S * D, 0.0f);
|
| 373 |
+
for (int64_t p = 0; p < TOTAL; p++) {
|
| 374 |
+
int32_t src = h_ri[p];
|
| 375 |
+
int s = src / K;
|
| 376 |
+
int k = src % K;
|
| 377 |
+
if (s < 0 || s >= S || k < 0 || k >= K) { printf(" bad idx p=%ld src=%d\n", p, src); continue; }
|
| 378 |
+
float w = bf16_to_float(h_tw[s * K + k]);
|
| 379 |
+
for (int d = 0; d < D; d++) {
|
| 380 |
+
manual[s * D + d] += w * bf16_to_float(h_down[p * D + d]);
|
| 381 |
+
}
|
| 382 |
+
}
|
| 383 |
+
// Convert to bf16 and compare to Python out_flat
|
| 384 |
+
auto refbuf = read_file(data_dir + "/out_flat.bin");
|
| 385 |
+
auto* ref = (const uint16_t*)refbuf.data();
|
| 386 |
+
double l2d=0, l2r=0, maxd=0;
|
| 387 |
+
for (int64_t i = 0; i < S*D; i++) {
|
| 388 |
+
float a = manual[i], b = bf16_to_float(ref[i]);
|
| 389 |
+
l2d += (a-b)*(a-b); l2r += b*b;
|
| 390 |
+
if (std::abs(a-b) > maxd) maxd = std::abs(a-b);
|
| 391 |
+
}
|
| 392 |
+
double rel_manual = std::sqrt(l2d) / (std::sqrt(l2r) + 1e-10);
|
| 393 |
+
printf(" [cmp] MANUAL(row_idx=src→flat) rel=%.4e max_abs=%.4f m[:4]=%.5f %.5f %.5f %.5f r[:4]=%.5f %.5f %.5f %.5f\n",
|
| 394 |
+
rel_manual, maxd,
|
| 395 |
+
manual[0], manual[1], manual[2], manual[3],
|
| 396 |
+
bf16_to_float(ref[0]), bf16_to_float(ref[1]), bf16_to_float(ref[2]), bf16_to_float(ref[3]));
|
| 397 |
+
|
| 398 |
+
// Alternative semantic: row_idx[p] = destination position
|
| 399 |
+
// In that case: p=src_row, dst=h_ri[p]
|
| 400 |
+
std::vector<float> manual2(S * D, 0.0f);
|
| 401 |
+
for (int64_t p = 0; p < TOTAL; p++) {
|
| 402 |
+
int32_t dst = h_ri[p];
|
| 403 |
+
int s = dst / K;
|
| 404 |
+
int k = dst % K;
|
| 405 |
+
if (s < 0 || s >= S || k < 0 || k >= K) continue;
|
| 406 |
+
float w = bf16_to_float(h_tw[s * K + k]);
|
| 407 |
+
for (int d = 0; d < D; d++) {
|
| 408 |
+
manual2[s * D + d] += w * bf16_to_float(h_down[p * D + d]);
|
| 409 |
+
}
|
| 410 |
+
}
|
| 411 |
+
double l2d2=0, l2r2=0, maxd2=0;
|
| 412 |
+
for (int64_t i = 0; i < S*D; i++) {
|
| 413 |
+
float a = manual2[i], b = bf16_to_float(ref[i]);
|
| 414 |
+
l2d2 += (a-b)*(a-b); l2r2 += b*b;
|
| 415 |
+
if (std::abs(a-b) > maxd2) maxd2 = std::abs(a-b);
|
| 416 |
+
}
|
| 417 |
+
double rel_manual2 = std::sqrt(l2d2) / (std::sqrt(l2r2) + 1e-10);
|
| 418 |
+
printf(" [cmp] MANUAL(row_idx=p→dst_flat) rel=%.4e max_abs=%.4f m[:4]=%.5f %.5f %.5f %.5f\n",
|
| 419 |
+
rel_manual2, maxd2,
|
| 420 |
+
manual2[0], manual2[1], manual2[2], manual2[3]);
|
| 421 |
+
}
|
| 422 |
+
|
| 423 |
+
// Manual finalize using cumsum (semantics-independent):
|
| 424 |
+
// For each (n, k), find p such that actual_s(p)=n AND expert(p)=topk_idx[n,k], then
|
| 425 |
+
// out[n] += topk_w[n,k] * down_out[p].
|
| 426 |
+
{
|
| 427 |
+
std::vector<uint16_t> h_down(TOTAL * D);
|
| 428 |
+
std::vector<int64_t> h_tpe(E);
|
| 429 |
+
std::vector<int32_t> h_tidx(S * K);
|
| 430 |
+
std::vector<uint16_t> h_tw(S * K);
|
| 431 |
+
std::vector<uint16_t> h_xn_all(S * D);
|
| 432 |
+
std::vector<uint16_t> h_ex_all(TOTAL * D);
|
| 433 |
+
ACL_CHECK(aclrtMemcpy(h_down.data(), TOTAL*D*2, down_out_dev.get(), TOTAL*D*2, ACL_MEMCPY_DEVICE_TO_HOST));
|
| 434 |
+
ACL_CHECK(aclrtMemcpy(h_tpe.data(), E*8, tokens_per_expert_dev.get(), E*8, ACL_MEMCPY_DEVICE_TO_HOST));
|
| 435 |
+
ACL_CHECK(aclrtMemcpy(h_tidx.data(), S*K*4, topk_idx_dev.get(), S*K*4, ACL_MEMCPY_DEVICE_TO_HOST));
|
| 436 |
+
ACL_CHECK(aclrtMemcpy(h_tw.data(), S*K*2, topk_w_dev.get(), S*K*2, ACL_MEMCPY_DEVICE_TO_HOST));
|
| 437 |
+
ACL_CHECK(aclrtMemcpy(h_xn_all.data(), S*D*2, xn_dev.get(), S*D*2, ACL_MEMCPY_DEVICE_TO_HOST));
|
| 438 |
+
ACL_CHECK(aclrtMemcpy(h_ex_all.data(), TOTAL*D*2, expanded_x_dev.get(), TOTAL*D*2, ACL_MEMCPY_DEVICE_TO_HOST));
|
| 439 |
+
|
| 440 |
+
// Build p → (actual_s, actual_expert).
|
| 441 |
+
// actual_s: find s with xn[s,0] == expanded_x[p,0]
|
| 442 |
+
// actual_expert: find e such that cumsum_tpe[e-1] <= p < cumsum_tpe[e]
|
| 443 |
+
std::vector<int> p_to_s(TOTAL), p_to_e(TOTAL);
|
| 444 |
+
int64_t cum = 0;
|
| 445 |
+
int cursor_e = 0;
|
| 446 |
+
for (int64_t p = 0; p < TOTAL; p++) {
|
| 447 |
+
while (cursor_e < E && p >= cum + h_tpe[cursor_e]) { cum += h_tpe[cursor_e]; cursor_e++; }
|
| 448 |
+
p_to_e[p] = cursor_e;
|
| 449 |
+
float ev = bf16_to_float(h_ex_all[p * D]);
|
| 450 |
+
int best = -1; float bd = 1e30f;
|
| 451 |
+
for (int s = 0; s < S; s++) {
|
| 452 |
+
float df = std::abs(bf16_to_float(h_xn_all[s * D]) - ev);
|
| 453 |
+
if (df < bd) { bd = df; best = s; }
|
| 454 |
+
}
|
| 455 |
+
p_to_s[p] = best;
|
| 456 |
+
}
|
| 457 |
+
|
| 458 |
+
// Build (n, k) → p lookup via (n, expert) → p
|
| 459 |
+
std::vector<float> manual_cum(S * D, 0.0f);
|
| 460 |
+
int found_count = 0;
|
| 461 |
+
for (int n = 0; n < S; n++) {
|
| 462 |
+
for (int k = 0; k < K; k++) {
|
| 463 |
+
int e = h_tidx[n * K + k];
|
| 464 |
+
float w = bf16_to_float(h_tw[n * K + k]);
|
| 465 |
+
// search p with p_to_s[p]==n and p_to_e[p]==e
|
| 466 |
+
int found_p = -1;
|
| 467 |
+
for (int64_t p = 0; p < TOTAL; p++) {
|
| 468 |
+
if (p_to_s[p] == n && p_to_e[p] == e) { found_p = p; break; }
|
| 469 |
+
}
|
| 470 |
+
if (found_p < 0) {
|
| 471 |
+
printf(" [!!!] not found: n=%d k=%d expert=%d\n", n, k, e);
|
| 472 |
+
continue;
|
| 473 |
+
}
|
| 474 |
+
found_count++;
|
| 475 |
+
for (int d = 0; d < D; d++)
|
| 476 |
+
manual_cum[n * D + d] += w * bf16_to_float(h_down[found_p * D + d]);
|
| 477 |
+
}
|
| 478 |
+
}
|
| 479 |
+
auto refbuf = read_file(data_dir + "/out_flat.bin");
|
| 480 |
+
auto* ref = (const uint16_t*)refbuf.data();
|
| 481 |
+
double l2d=0, l2r=0, maxd=0;
|
| 482 |
+
for (int64_t i = 0; i < S*D; i++) {
|
| 483 |
+
float a = manual_cum[i], b = bf16_to_float(ref[i]);
|
| 484 |
+
l2d += (a-b)*(a-b); l2r += b*b;
|
| 485 |
+
if (std::abs(a-b) > maxd) maxd = std::abs(a-b);
|
| 486 |
+
}
|
| 487 |
+
double rel_cum = std::sqrt(l2d) / (std::sqrt(l2r) + 1e-10);
|
| 488 |
+
printf(" [cmp] MANUAL_CUMSUM (p via expert cumsum) rel=%.4e max=%.4f found=%d/40 m[:4]=%.5f %.5f %.5f %.5f\n",
|
| 489 |
+
rel_cum, maxd, found_count, manual_cum[0], manual_cum[1], manual_cum[2], manual_cum[3]);
|
| 490 |
+
}
|
| 491 |
+
|
| 492 |
+
// Dump all expanded_x[p, 0] and all xn[s, 0] to determine the mapping.
|
| 493 |
+
{
|
| 494 |
+
std::vector<uint16_t> h_xn_all(S * D);
|
| 495 |
+
ACL_CHECK(aclrtMemcpy(h_xn_all.data(), S*D*2, xn_dev.get(), S*D*2, ACL_MEMCPY_DEVICE_TO_HOST));
|
| 496 |
+
std::vector<uint16_t> h_ex_all(TOTAL * D);
|
| 497 |
+
ACL_CHECK(aclrtMemcpy(h_ex_all.data(), TOTAL*D*2, expanded_x_dev.get(), TOTAL*D*2, ACL_MEMCPY_DEVICE_TO_HOST));
|
| 498 |
+
printf(" xn[s, 0]: ");
|
| 499 |
+
for (int s = 0; s < S; s++) printf("%.5f ", bf16_to_float(h_xn_all[s * D]));
|
| 500 |
+
printf("\n expanded_x[p, 0]: ");
|
| 501 |
+
for (int p = 0; p < TOTAL; p++) printf("%.5f ", bf16_to_float(h_ex_all[p * D]));
|
| 502 |
+
printf("\n mapping p→s (by matching expanded_x[p,0] to xn[s,0]): ");
|
| 503 |
+
for (int p = 0; p < TOTAL; p++) {
|
| 504 |
+
float e = bf16_to_float(h_ex_all[p * D]);
|
| 505 |
+
int match = -1; float best = 1e30f;
|
| 506 |
+
for (int s = 0; s < S; s++) {
|
| 507 |
+
float df = std::abs(bf16_to_float(h_xn_all[s * D]) - e);
|
| 508 |
+
if (df < best) { best = df; match = s; }
|
| 509 |
+
}
|
| 510 |
+
printf("%d ", match);
|
| 511 |
+
}
|
| 512 |
+
printf("\n");
|
| 513 |
+
}
|
| 514 |
+
|
| 515 |
+
// Dump gate_out[p=4, :8] — gate activation of xn[0] via expert 10
|
| 516 |
+
{
|
| 517 |
+
std::vector<uint16_t> h_gate(I);
|
| 518 |
+
// NOTE: gate_out_dev was overwritten by silu+mul. So we need to reload from scratch.
|
| 519 |
+
// Instead just show down_out[4, :4].
|
| 520 |
+
std::vector<uint16_t> h_d(D);
|
| 521 |
+
ACL_CHECK(aclrtMemcpy(h_d.data(), D*2, (char*)down_out_dev.get() + 4*D*2, D*2, ACL_MEMCPY_DEVICE_TO_HOST));
|
| 522 |
+
printf(" down_out[p=4, :4] (s=0, k=0, expert=10): %.5f %.5f %.5f %.5f\n",
|
| 523 |
+
bf16_to_float(h_d[0]), bf16_to_float(h_d[1]), bf16_to_float(h_d[2]), bf16_to_float(h_d[3]));
|
| 524 |
+
// If GMM is correct, down_out[4] ~ ref[0] / topk_w[0,0]. ref[0,:4]=[-0.025, -0.007, 0.005, -0.008] / 0.224 ~ [-0.113, -0.031, 0.024, -0.036].
|
| 525 |
+
// But it's just ONE contribution so hard to compare directly.
|
| 526 |
+
}
|
| 527 |
+
|
| 528 |
+
// Single-expert verification using linear_hf: compute gate/up/down for (xn[0], expert=10)
|
| 529 |
+
// and compare with GMM's down_out at the corresponding position.
|
| 530 |
+
// linear_hf expects HF-layout weight [out_features, in_features]; our stacked gate_exps/up_exps
|
| 531 |
+
// are [E, D, I] — meaning per-expert shape is [D, I] (K, N) NOT HF [I, D]. So we can NOT directly
|
| 532 |
+
// linear_hf from gate_exps. Instead, load the expert-10 weight fresh and use linear_hf.
|
| 533 |
+
{
|
| 534 |
+
std::vector<int32_t> h_tidx_local(S * K);
|
| 535 |
+
ACL_CHECK(aclrtMemcpy(h_tidx_local.data(), S*K*4, topk_idx_dev.get(), S*K*4, ACL_MEMCPY_DEVICE_TO_HOST));
|
| 536 |
+
int target_expert = h_tidx_local[0 * K + 0]; // topk_idx[0, 0] should be 10 from Python ref
|
| 537 |
+
printf("\n === Single-expert linear_hf vs GMM sanity (token 0, expert %d) ===\n", target_expert);
|
| 538 |
+
|
| 539 |
+
// Recompute p_to_s and p_to_e from host data (scoped locally).
|
| 540 |
+
std::vector<int64_t> h_tpe2(E);
|
| 541 |
+
std::vector<uint16_t> h_xn_all2(S * D);
|
| 542 |
+
std::vector<uint16_t> h_ex_all2(TOTAL * D);
|
| 543 |
+
ACL_CHECK(aclrtMemcpy(h_tpe2.data(), E*8, tokens_per_expert_dev.get(), E*8, ACL_MEMCPY_DEVICE_TO_HOST));
|
| 544 |
+
ACL_CHECK(aclrtMemcpy(h_xn_all2.data(), S*D*2, xn_dev.get(), S*D*2, ACL_MEMCPY_DEVICE_TO_HOST));
|
| 545 |
+
ACL_CHECK(aclrtMemcpy(h_ex_all2.data(), TOTAL*D*2, expanded_x_dev.get(), TOTAL*D*2, ACL_MEMCPY_DEVICE_TO_HOST));
|
| 546 |
+
std::vector<int> p_to_s(TOTAL), p_to_e(TOTAL);
|
| 547 |
+
{
|
| 548 |
+
int64_t cum = 0; int ce = 0;
|
| 549 |
+
for (int64_t p = 0; p < TOTAL; p++) {
|
| 550 |
+
while (ce < E && p >= cum + h_tpe2[ce]) { cum += h_tpe2[ce]; ce++; }
|
| 551 |
+
p_to_e[p] = ce;
|
| 552 |
+
float ev = bf16_to_float(h_ex_all2[p * D]);
|
| 553 |
+
int best = -1; float bd = 1e30f;
|
| 554 |
+
for (int s = 0; s < S; s++) {
|
| 555 |
+
float df = std::abs(bf16_to_float(h_xn_all2[s * D]) - ev);
|
| 556 |
+
if (df < bd) { bd = df; best = s; }
|
| 557 |
+
}
|
| 558 |
+
p_to_s[p] = best;
|
| 559 |
+
}
|
| 560 |
+
}
|
| 561 |
+
|
| 562 |
+
DeviceBuffer g_w, u_w, d_w;
|
| 563 |
+
char ename[256];
|
| 564 |
+
snprintf(ename, sizeof(ename), "model.layers.0.mlp.experts.%d.gate_proj.weight", target_expert);
|
| 565 |
+
if (!dw.st().get(ename)) { printf(" missing %s\n", ename); goto after_sanity; }
|
| 566 |
+
|
| 567 |
+
// Load full per-expert weight using public helpers (indirectly via loader).
|
| 568 |
+
// Easiest: use load_tensor_full_ via friend access... Instead, use st_ directly.
|
| 569 |
+
{
|
| 570 |
+
auto* m_gate = dw.st().get(ename);
|
| 571 |
+
DeviceBuffer gw_buf(m_gate->nbytes);
|
| 572 |
+
ACL_CHECK(aclrtMemcpy(gw_buf.get(), m_gate->nbytes, dw.st().data_ptr(*m_gate), m_gate->nbytes, ACL_MEMCPY_HOST_TO_DEVICE));
|
| 573 |
+
g_w = std::move(gw_buf);
|
| 574 |
+
|
| 575 |
+
snprintf(ename, sizeof(ename), "model.layers.0.mlp.experts.%d.up_proj.weight", target_expert);
|
| 576 |
+
auto* m_up = dw.st().get(ename);
|
| 577 |
+
DeviceBuffer uw_buf(m_up->nbytes);
|
| 578 |
+
ACL_CHECK(aclrtMemcpy(uw_buf.get(), m_up->nbytes, dw.st().data_ptr(*m_up), m_up->nbytes, ACL_MEMCPY_HOST_TO_DEVICE));
|
| 579 |
+
u_w = std::move(uw_buf);
|
| 580 |
+
|
| 581 |
+
snprintf(ename, sizeof(ename), "model.layers.0.mlp.experts.%d.down_proj.weight", target_expert);
|
| 582 |
+
auto* m_down = dw.st().get(ename);
|
| 583 |
+
DeviceBuffer dw_buf(m_down->nbytes);
|
| 584 |
+
ACL_CHECK(aclrtMemcpy(dw_buf.get(), m_down->nbytes, dw.st().data_ptr(*m_down), m_down->nbytes, ACL_MEMCPY_HOST_TO_DEVICE));
|
| 585 |
+
d_w = std::move(dw_buf);
|
| 586 |
+
}
|
| 587 |
+
|
| 588 |
+
// Compute gate = xn[0] @ gate_w.T → [I]; up = xn[0] @ up_w.T → [I]; act; down = act @ down_w.T → [D]
|
| 589 |
+
DeviceBuffer xn0_dev(D * 2);
|
| 590 |
+
ACL_CHECK(aclrtMemcpy(xn0_dev.get(), D*2, xn_dev.get(), D*2, ACL_MEMCPY_DEVICE_TO_DEVICE));
|
| 591 |
+
|
| 592 |
+
DeviceBuffer gate_v(I * 2), up_v(I * 2), act_v(I * 2), down_v(D * 2);
|
| 593 |
+
auto t_xn0 = make_contig_tensor(xn0_dev.get(), ACL_BF16, {1, D});
|
| 594 |
+
auto t_gate = make_contig_tensor(gate_v.get(), ACL_BF16, {1, I});
|
| 595 |
+
auto t_up = make_contig_tensor(up_v.get(), ACL_BF16, {1, I});
|
| 596 |
+
auto t_act = make_contig_tensor(act_v.get(), ACL_BF16, {1, I});
|
| 597 |
+
auto t_down = make_contig_tensor(down_v.get(), ACL_BF16, {1, D});
|
| 598 |
+
linear_hf(rt.stream(), t_xn0.get(), g_w.get(), ACL_BF16, I, D, t_gate.get()); // gate_proj HF [I, D]
|
| 599 |
+
linear_hf(rt.stream(), t_xn0.get(), u_w.get(), ACL_BF16, I, D, t_up.get());
|
| 600 |
+
rt.sync();
|
| 601 |
+
silu(rt.stream(), t_gate.get(), t_act.get());
|
| 602 |
+
mul(rt.stream(), t_act.get(), t_up.get(), t_act.get());
|
| 603 |
+
rt.sync();
|
| 604 |
+
linear_hf(rt.stream(), t_act.get(), d_w.get(), ACL_BF16, D, I, t_down.get()); // down_proj HF [D, I]
|
| 605 |
+
rt.sync();
|
| 606 |
+
|
| 607 |
+
std::vector<uint16_t> h_down_lin(D);
|
| 608 |
+
ACL_CHECK(aclrtMemcpy(h_down_lin.data(), D*2, down_v.get(), D*2, ACL_MEMCPY_DEVICE_TO_HOST));
|
| 609 |
+
|
| 610 |
+
// Find the p in GMM output that corresponds to (s=0, expert=target_expert)
|
| 611 |
+
int found_p = -1;
|
| 612 |
+
for (int64_t p = 0; p < TOTAL; p++) {
|
| 613 |
+
if (p_to_s[p] == 0 && p_to_e[p] == target_expert) { found_p = p; break; }
|
| 614 |
+
}
|
| 615 |
+
if (found_p >= 0) {
|
| 616 |
+
std::vector<uint16_t> h_down_gmm(D);
|
| 617 |
+
ACL_CHECK(aclrtMemcpy(h_down_gmm.data(), D*2, (char*)down_out_dev.get() + found_p*D*2, D*2, ACL_MEMCPY_DEVICE_TO_HOST));
|
| 618 |
+
double l2d=0, l2r=0, maxd=0;
|
| 619 |
+
for (int i = 0; i < D; i++) {
|
| 620 |
+
float a = bf16_to_float(h_down_gmm[i]), b = bf16_to_float(h_down_lin[i]);
|
| 621 |
+
l2d += (a-b)*(a-b); l2r += b*b;
|
| 622 |
+
if (std::abs(a-b) > maxd) maxd = std::abs(a-b);
|
| 623 |
+
}
|
| 624 |
+
double rel = std::sqrt(l2d) / (std::sqrt(l2r) + 1e-10);
|
| 625 |
+
printf(" GMM down_out[p=%d] vs linear_hf down: rel=%.4e max=%.4f\n", found_p, rel, maxd);
|
| 626 |
+
printf(" GMM[:4]: %.5f %.5f %.5f %.5f\n",
|
| 627 |
+
bf16_to_float(h_down_gmm[0]), bf16_to_float(h_down_gmm[1]), bf16_to_float(h_down_gmm[2]), bf16_to_float(h_down_gmm[3]));
|
| 628 |
+
printf(" linear[:4]: %.5f %.5f %.5f %.5f\n",
|
| 629 |
+
bf16_to_float(h_down_lin[0]), bf16_to_float(h_down_lin[1]), bf16_to_float(h_down_lin[2]), bf16_to_float(h_down_lin[3]));
|
| 630 |
+
} else {
|
| 631 |
+
printf(" not found p for (s=0, expert=%d)\n", target_expert);
|
| 632 |
+
}
|
| 633 |
+
}
|
| 634 |
+
after_sanity:;
|
| 635 |
+
|
| 636 |
+
// Direct verification: gate_exps[expert_10, :4, :4] vs HF gate_proj_10 (transposed).
|
| 637 |
+
{
|
| 638 |
+
int expert_id = 10;
|
| 639 |
+
std::vector<uint16_t> h_stacked(4 * 4);
|
| 640 |
+
// gate_exps shape [E, D, I]. Expert 10 starts at offset expert_id * D * I * 2.
|
| 641 |
+
// Read the first 4 rows (d=0..3), first 4 cols (i=0..3). Row stride = I * 2 bytes.
|
| 642 |
+
for (int d = 0; d < 4; d++) {
|
| 643 |
+
ACL_CHECK(aclrtMemcpy(h_stacked.data() + d*4, 8,
|
| 644 |
+
(char*)moe.gate_exps.get() + (expert_id * D * I + d * I) * 2, 8,
|
| 645 |
+
ACL_MEMCPY_DEVICE_TO_HOST));
|
| 646 |
+
}
|
| 647 |
+
char ename[256];
|
| 648 |
+
snprintf(ename, sizeof(ename), "model.layers.0.mlp.experts.%d.gate_proj.weight", expert_id);
|
| 649 |
+
auto* m = dw.st().get(ename);
|
| 650 |
+
// HF gate_proj [I, D] row-major. Element at (i, d) is at offset (i*D + d)*2.
|
| 651 |
+
// Expected gate_exps[10, d, i] == HF_gate_proj[10][i, d].
|
| 652 |
+
// So for d in 0..3, i in 0..3: expected is HF[i, d].
|
| 653 |
+
std::vector<uint16_t> h_expected(4 * 4);
|
| 654 |
+
auto* hf = (const uint16_t*)dw.st().data_ptr(*m);
|
| 655 |
+
for (int d = 0; d < 4; d++) {
|
| 656 |
+
for (int i = 0; i < 4; i++) {
|
| 657 |
+
h_expected[d*4 + i] = hf[i * D + d]; // HF[i, d]
|
| 658 |
+
}
|
| 659 |
+
}
|
| 660 |
+
printf("\n === gate_exps[10, :4, :4] layout check ===\n");
|
| 661 |
+
printf(" stacked: ");
|
| 662 |
+
for (int i = 0; i < 16; i++) printf("%.5f ", bf16_to_float(h_stacked[i]));
|
| 663 |
+
printf("\n expected: ");
|
| 664 |
+
for (int i = 0; i < 16; i++) printf("%.5f ", bf16_to_float(h_expected[i]));
|
| 665 |
+
printf("\n");
|
| 666 |
+
int mism = 0;
|
| 667 |
+
for (int i = 0; i < 16; i++) if (h_stacked[i] != h_expected[i]) mism++;
|
| 668 |
+
printf(" mismatches: %d / 16\n", mism);
|
| 669 |
+
}
|
| 670 |
+
|
| 671 |
+
printf("\n=== Final (with residual) ===\n");
|
| 672 |
+
double rel = compare_bf16("final_out", final_dev.get(), S * D, "final_out.bin");
|
| 673 |
+
bool pass = rel < 5e-2;
|
| 674 |
+
printf("\n%s\n", pass ? "=== test_moe_layer PASS ===" : "=== test_moe_layer FAIL ===");
|
| 675 |
+
return pass ? 0 : 1;
|
| 676 |
+
}
|
|
@@ -0,0 +1,190 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// test_op_support.cpp — smoke test which aclnn ops actually RUN on 910 初代.
|
| 2 |
+
// Just call each candidate op with small tensors; report SUCCESS/FAILURE.
|
| 3 |
+
// Guides optimization feasibility analysis.
|
| 4 |
+
#include "acl_common.h"
|
| 5 |
+
#include "acl_runtime.h"
|
| 6 |
+
#include "aclnn_ops.h"
|
| 7 |
+
#include <acl/acl.h>
|
| 8 |
+
#include <aclnnop/aclnn_add_rms_norm.h>
|
| 9 |
+
#include <aclnnop/aclnn_npu_format_cast.h>
|
| 10 |
+
#include <aclnnop/aclnn_matmul.h>
|
| 11 |
+
|
| 12 |
+
#include <cstdio>
|
| 13 |
+
#include <cstring>
|
| 14 |
+
#include <vector>
|
| 15 |
+
|
| 16 |
+
// Widen a bfloat16 bit pattern to float: bf16 is exactly the high 16 bits of
// an IEEE-754 binary32, so shift left and reinterpret.
static float bf16_to_float(uint16_t x) {
    const uint32_t bits = static_cast<uint32_t>(x) << 16;
    float out;
    std::memcpy(&out, &bits, sizeof(out));
    return out;
}
|
| 17 |
+
// Convert float → bfloat16 with round-to-nearest-even.
// BUG FIX: the plain rounding formula can carry a NaN's mantissa into the
// exponent field and emit +/-Inf (e.g. bits 0x7F800001 + 0x7FFF truncates to
// 0x7F80 = +Inf). Detect NaN first and return a quiet NaN with the payload's
// top bits preserved. All non-NaN inputs behave exactly as before.
static uint16_t f_to_bf16(float f) {
    uint32_t u;
    std::memcpy(&u, &f, 4);
    if ((u & 0x7FFFFFFFu) > 0x7F800000u) {
        // NaN: truncate, then force a mantissa bit so the result stays NaN.
        return (uint16_t)((u >> 16) | 0x0040u);
    }
    // Round to nearest, ties to even: add 0x7FFF plus the parity of the
    // bit that becomes the result's LSB, then keep the top 16 bits.
    return (uint16_t)((u + 0x7FFF + ((u >> 16) & 1)) >> 16);
}
|
| 18 |
+
|
| 19 |
+
// Smoke-test the fused Add+RmsNorm aclnn kernel on tiny [1, 16] BF16 tensors.
// Returns "OK" on success or a short description of the first failing step.
static const char* test_add_rms_norm(AclRuntime& rt) {
    constexpr int64_t kDim = 16;

    // Host-side inputs: two addends and a unit gamma, all BF16.
    std::vector<uint16_t> host_a(kDim, f_to_bf16(0.5f));
    std::vector<uint16_t> host_b(kDim, f_to_bf16(0.3f));
    std::vector<uint16_t> host_gamma(kDim, f_to_bf16(1.0f));

    DeviceBuffer dev_a(kDim * 2), dev_b(kDim * 2), dev_gamma(kDim * 2);
    DeviceBuffer dev_y(kDim * 2), dev_rstd(1 * 4), dev_xout(kDim * 2);
    ACL_CHECK(aclrtMemcpy(dev_a.get(), kDim * 2, host_a.data(), kDim * 2, ACL_MEMCPY_HOST_TO_DEVICE));
    ACL_CHECK(aclrtMemcpy(dev_b.get(), kDim * 2, host_b.data(), kDim * 2, ACL_MEMCPY_HOST_TO_DEVICE));
    ACL_CHECK(aclrtMemcpy(dev_gamma.get(), kDim * 2, host_gamma.data(), kDim * 2, ACL_MEMCPY_HOST_TO_DEVICE));

    auto t_a     = make_contig_tensor(dev_a.get(),     ACL_BF16,  {1, kDim});
    auto t_b     = make_contig_tensor(dev_b.get(),     ACL_BF16,  {1, kDim});
    auto t_gamma = make_contig_tensor(dev_gamma.get(), ACL_BF16,  {kDim});
    auto t_y     = make_contig_tensor(dev_y.get(),     ACL_BF16,  {1, kDim});
    auto t_rstd  = make_contig_tensor(dev_rstd.get(),  ACL_FLOAT, {1});
    auto t_xout  = make_contig_tensor(dev_xout.get(),  ACL_BF16,  {1, kDim});

    // Standard two-phase aclnn invocation: query workspace, then launch.
    uint64_t ws_size = 0;
    aclOpExecutor* executor = nullptr;
    aclnnStatus st = aclnnAddRmsNormGetWorkspaceSize(t_a.get(), t_b.get(), t_gamma.get(), 1e-6,
                                                     t_y.get(), t_rstd.get(), t_xout.get(),
                                                     &ws_size, &executor);
    if (st != 0) return "GetWorkspaceSize FAILED";
    DeviceBuffer workspace;
    if (ws_size > 0) workspace.alloc(ws_size);
    st = aclnnAddRmsNorm(workspace.get(), ws_size, executor, rt.stream());
    if (st != 0) return "aclnnAddRmsNorm FAILED (kernel not available on 910?)";
    if (aclrtSynchronizeStream(rt.stream()) != 0) return "sync FAILED";
    return "OK";
}
|
| 48 |
+
|
| 49 |
+
// Smoke-test ND → FRACTAL_NZ format conversion on a small [16, 16] BF16 tensor.
// Returns "OK" on success or a short description of the first failing step.
// NOTE(review): ownership of the shape array returned by
// aclnnNpuFormatCastCalculateSizeAndFormat is not documented here — confirm
// whether the caller must free it.
static const char* test_npu_format_cast_nz(AclRuntime& rt) {
    constexpr int64_t kRows = 16, kCols = 16;
    std::vector<uint16_t> host(kRows * kCols, f_to_bf16(1.0f));
    DeviceBuffer nd_buf(kRows * kCols * 2);
    ACL_CHECK(aclrtMemcpy(nd_buf.get(), kRows * kCols * 2, host.data(), kRows * kCols * 2, ACL_MEMCPY_HOST_TO_DEVICE));
    auto t_nd = make_contig_tensor(nd_buf.get(), ACL_BF16, {kRows, kCols});

    // Step 1: ask the runtime for the NZ-padded shape and the actual format id.
    int64_t* nz_dims = nullptr;
    uint64_t nz_rank = 0;
    int nz_fmt = 0;
    aclnnStatus st = aclnnNpuFormatCastCalculateSizeAndFormat(
        t_nd.get(), /*dstFormat=*/29 /* FRACTAL_NZ */,
        /*additionalDtype=*/27 /* BF16 */,
        &nz_dims, &nz_rank, &nz_fmt);
    if (st != 0) return "CalculateSizeAndFormat FAILED";

    // Step 2: allocate the destination with the reported shape and run the cast.
    std::vector<int64_t> nz_shape(nz_dims, nz_dims + nz_rank);
    int64_t elems = 1;
    for (int64_t d : nz_shape) elems *= d;
    DeviceBuffer nz_buf(elems * 2);
    auto t_nz = make_acl_tensor(nz_buf.get(), ACL_BF16, nz_shape, {}, (aclFormat)nz_fmt);

    uint64_t ws_size = 0;
    aclOpExecutor* executor = nullptr;
    st = aclnnNpuFormatCastGetWorkspaceSize(t_nd.get(), t_nz.get(), &ws_size, &executor);
    if (st != 0) return "FormatCast GetWorkspaceSize FAILED";
    DeviceBuffer workspace;
    if (ws_size > 0) workspace.alloc(ws_size);
    st = aclnnNpuFormatCast(workspace.get(), ws_size, executor, rt.stream());
    if (st != 0) return "aclnnNpuFormatCast FAILED (NZ not supported on 910?)";
    if (aclrtSynchronizeStream(rt.stream()) != 0) return "sync FAILED";
    return "OK";
}
|
| 83 |
+
|
| 84 |
+
// Smoke-test aclnnMatmul with an NZ-format weight: first cast W [K, N] from
// ND to FRACTAL_NZ, then run x (ND) × W (NZ) → y [M, N].
// Returns "OK" on success or a short description of the first failing step.
static const char* test_matmul_nz(AclRuntime& rt) {
    constexpr int64_t kM = 16, kK = 32, kN = 16;
    std::vector<uint16_t> host_x(kM * kK, f_to_bf16(0.1f));
    std::vector<uint16_t> host_w(kK * kN, f_to_bf16(0.1f));
    DeviceBuffer x_buf(kM * kK * 2), w_buf(kK * kN * 2), y_buf(kM * kN * 2);
    ACL_CHECK(aclrtMemcpy(x_buf.get(), kM * kK * 2, host_x.data(), kM * kK * 2, ACL_MEMCPY_HOST_TO_DEVICE));
    ACL_CHECK(aclrtMemcpy(w_buf.get(), kK * kN * 2, host_w.data(), kK * kN * 2, ACL_MEMCPY_HOST_TO_DEVICE));

    auto t_x    = make_contig_tensor(x_buf.get(), ACL_BF16, {kM, kK});
    auto t_w_nd = make_contig_tensor(w_buf.get(), ACL_BF16, {kK, kN});

    // Stage 1: convert the weight to FRACTAL_NZ (format id 29, BF16 dtype 27).
    int64_t* nz_dims = nullptr;
    uint64_t nz_rank = 0;
    int nz_fmt = 0;
    if (aclnnNpuFormatCastCalculateSizeAndFormat(t_w_nd.get(), 29, 27, &nz_dims, &nz_rank, &nz_fmt) != 0)
        return "calc NZ FAILED";
    std::vector<int64_t> nz_shape(nz_dims, nz_dims + nz_rank);
    int64_t elems = 1;
    for (int64_t d : nz_shape) elems *= d;
    DeviceBuffer w_nz_buf(elems * 2);
    auto t_w_nz = make_acl_tensor(w_nz_buf.get(), ACL_BF16, nz_shape, {}, (aclFormat)nz_fmt);

    uint64_t ws_size = 0;
    aclOpExecutor* executor = nullptr;
    aclnnStatus st = aclnnNpuFormatCastGetWorkspaceSize(t_w_nd.get(), t_w_nz.get(), &ws_size, &executor);
    if (st != 0) return "NZ cast ws FAILED";
    DeviceBuffer cast_ws;
    if (ws_size > 0) cast_ws.alloc(ws_size);
    if (aclnnNpuFormatCast(cast_ws.get(), ws_size, executor, rt.stream()) != 0) return "NZ cast EXEC FAILED";
    if (aclrtSynchronizeStream(rt.stream()) != 0) return "NZ cast sync FAILED";

    // Stage 2: MatMul with the ND activation and the NZ weight.
    auto t_y = make_contig_tensor(y_buf.get(), ACL_BF16, {kM, kN});
    ws_size = 0;
    executor = nullptr;
    st = aclnnMatmulGetWorkspaceSize(t_x.get(), t_w_nz.get(), t_y.get(), 0 /*trans*/, &ws_size, &executor);
    if (st != 0) return "MatMul NZ GetWorkspaceSize FAILED";
    DeviceBuffer mm_ws;
    if (ws_size > 0) mm_ws.alloc(ws_size);
    if (aclnnMatmul(mm_ws.get(), ws_size, executor, rt.stream()) != 0) return "MatMul NZ EXEC FAILED (MatMul doesn't accept NZ on 910?)";
    if (aclrtSynchronizeStream(rt.stream()) != 0) return "MatMul NZ sync FAILED";
    return "OK";
}
|
| 123 |
+
|
| 124 |
+
// Probe for optimization B (compute/comm overlap): create a SECOND stream
// and verify a trivial async H2D copy completes on it.
// Returns "OK" on success, or a short static description of the failing step.
static const char* test_multi_stream(AclRuntime& rt) {
    aclrtStream s2 = nullptr;
    if (aclrtCreateStream(&s2) != 0) return "aclrtCreateStream FAILED";

    // Simple dummy op on s2.
    DeviceBuffer x(16 * 2);
    std::vector<uint16_t> hx(16, 0);
    // BUGFIX: the stream previously leaked on both failure paths below —
    // destroy it before every return, not just on success.
    if (aclrtMemcpyAsync(x.get(), 16 * 2, hx.data(), 16 * 2, ACL_MEMCPY_HOST_TO_DEVICE, s2) != 0) {
        aclrtDestroyStream(s2);
        return "memcpy on s2 FAILED";
    }
    // Sync before hx goes out of scope: the async copy reads host memory.
    if (aclrtSynchronizeStream(s2) != 0) {
        aclrtDestroyStream(s2);
        return "sync s2 FAILED";
    }
    aclrtDestroyStream(s2);
    return "OK";
}
|
| 136 |
+
|
| 137 |
+
// BUGFIX: this #include previously sat INSIDE main()'s body. The preprocessor
// would paste the header's declarations (namespaces / extern "C" blocks) into
// block scope, which is ill-formed; file scope is the only portable placement.
#include <aclnnop/aclnn_inplace_add_rms_norm.h>

// Smoke test driver: probes which fused/format-cast aclnn ops the Ascend 910
// (initial generation) actually supports, then prints a feasibility summary
// for the three candidate optimizations.
int main() {
    AclRuntime rt;
    rt.init(0);  // device 0; NOTE(review): return value ignored — confirm init() aborts internally on failure

    printf("=== 910 op support smoke test ===\n");

    const char* r1 = test_add_rms_norm(rt);
    printf("  aclnnAddRmsNorm (fused Add+RmsNorm): %s\n", r1);

    const char* r2 = test_npu_format_cast_nz(rt);
    printf("  aclnnNpuFormatCast → FRACTAL_NZ: %s\n", r2);

    const char* r3 = test_matmul_nz(rt);
    printf("  aclnnMatmul with NZ weight: %s\n", r3);

    const char* r4 = test_multi_stream(rt);
    printf("  Multi-stream (compute/comm overlap): %s\n", r4);

    // More candidates
    printf("\n=== Additional 910 op candidates ===\n");

    // InplaceAddRmsNorm: x1 += x2, then RmsNorm(x1) in one fused kernel.
    {
        const int64_t D = 16;
        std::vector<uint16_t> h(D, f_to_bf16(0.5f)), hg(D, f_to_bf16(1.0f));
        DeviceBuffer x1(D * 2), x2(D * 2), g(D * 2), rstd(4);  // rstd: one fp32 scalar
        ACL_CHECK(aclrtMemcpy(x1.get(), D * 2, h.data(), D * 2, ACL_MEMCPY_HOST_TO_DEVICE));
        ACL_CHECK(aclrtMemcpy(x2.get(), D * 2, h.data(), D * 2, ACL_MEMCPY_HOST_TO_DEVICE));
        ACL_CHECK(aclrtMemcpy(g.get(), D * 2, hg.data(), D * 2, ACL_MEMCPY_HOST_TO_DEVICE));
        auto tx1 = make_contig_tensor(x1.get(), ACL_BF16, {1, D});
        auto tx2 = make_contig_tensor(x2.get(), ACL_BF16, {1, D});
        auto tg  = make_contig_tensor(g.get(), ACL_BF16, {D});
        auto tr  = make_contig_tensor(rstd.get(), ACL_FLOAT, {1});
        uint64_t ws = 0;
        aclOpExecutor* e = nullptr;
        aclnnStatus s = aclnnInplaceAddRmsNormGetWorkspaceSize(tx1.get(), tx2.get(), tg.get(), 1e-6,
                                                               tr.get(), &ws, &e);
        printf("  aclnnInplaceAddRmsNorm: %s\n", s == 0 ? "GetWS OK" : "GetWS FAILED");
        if (s == 0) {
            DeviceBuffer wb;
            if (ws > 0) wb.alloc(ws);
            s = aclnnInplaceAddRmsNorm(wb.get(), ws, e, rt.stream());
            printf("    exec: %s\n", s == 0 ? "OK" : "FAILED");
        }
    }

    // Test HCCL AllReduce on a separate stream
    printf("  HCCL AllReduce on stream2: requires TP>1, skipped in this smoke test\n");

    printf("\n=== FINAL Feasibility Summary ===\n");
    printf("  Optimization A (FRACTAL_NZ): INFEASIBLE (910 不支持)\n");
    printf("  Optimization B (multi-stream): FEASIBLE\n");
    printf("  Optimization C (Add+RmsNorm): INFEASIBLE (910 无 kernel)\n");
    return 0;
}
|