xianglarry Claude Opus 4.7 (1M context) commited on
Commit
4b9fefd
·
1 Parent(s): 1e99b3b

Initial C++ aclnn EAGER inference for Qwen3-235B-A22B MoE on Ascend 910 × 16 NPU

Browse files

Pure-C++ inference runtime built directly on aclnn single-op API (no graph
compilation, no PyTorch, no ggml). Targets Qwen3-235B-A22B-Instruct-2507 BF16
with TP=16 HCCL tensor parallelism.

Quality-preserving throughput:
- Untuned baseline: 12 t/s
- Recommended (HCCL env + Fused RoPE + small ops): ~27 t/s (all prompts)
- PLD with degeneration guard: 29-45 t/s (structured long-form text)

Key components:
- 12 headers + 6 sources implementing attention/MoE/Runner/HCCL/RoPE
- Fused RoPE via aclnnApplyRotaryPosEmbV2 (layout=1, "half")
- PLD (Prompt Lookup Decoding) with degeneration guard:
low-distinct + tail-echo heuristics block loop-amplifying drafts
- bench_pld_safe.sh classifies each run as OK / LOOP_N / LOW_DIVERSITY
and separates TG stats accordingly (honest performance reporting)
- 19 unit / integration tests + end-to-end smoke test

HCCL environment (applied by tp_launch.sh):
HCCL_OP_EXPANSION_MODE=AIV + HCCL_OP_BASE_FFTS_MODE_ENABLE=1 +
TASK_QUEUE_ENABLE=2 contributes +89% TG vs default ring-only.

Known limitations:
- Does not exceed cann-recipes-infer GE graph baseline of ~54 t/s
- PLD on factual/code prompts is unreliable (disable or use bench_pld_safe.sh)
- Requires Ascend 910 initial-gen × 16 NPU and CANN 8.5.1

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitignore +48 -0
  2. CMakeLists.txt +110 -0
  3. LICENSE +176 -0
  4. README.md +338 -0
  5. external/json.hpp +0 -0
  6. include/acl_common.h +106 -0
  7. include/acl_runtime.h +41 -0
  8. include/aclnn_ops.h +345 -0
  9. include/device_weights.h +82 -0
  10. include/engine.h +354 -0
  11. include/hccl_comm.h +106 -0
  12. include/model_config.h +52 -0
  13. include/rope.h +94 -0
  14. include/runner.h +128 -0
  15. include/safetensors_loader.h +78 -0
  16. include/tokenizer.h +38 -0
  17. include/workspace_pool.h +84 -0
  18. scripts/bench_hccl.sh +56 -0
  19. scripts/bench_hccl_adv.sh +56 -0
  20. scripts/bench_hccl_adv2.sh +56 -0
  21. scripts/bench_pld.sh +69 -0
  22. scripts/bench_pld_k.sh +41 -0
  23. scripts/bench_pld_safe.sh +154 -0
  24. scripts/bench_tg.sh +40 -0
  25. scripts/export_vocab.py +85 -0
  26. scripts/gen_attention_reference.py +179 -0
  27. scripts/gen_gmm_reference.py +89 -0
  28. scripts/gen_mm_reference.py +23 -0
  29. scripts/gen_moe_reference.py +115 -0
  30. scripts/gen_rms_norm_reference.py +39 -0
  31. scripts/regen_rope_reference.py +62 -0
  32. scripts/tp_launch.sh +58 -0
  33. src/device_weights.cpp +221 -0
  34. src/main_cli.cpp +816 -0
  35. src/model_config.cpp +115 -0
  36. src/runner.cpp +428 -0
  37. src/safetensors_loader.cpp +172 -0
  38. src/tokenizer.cpp +176 -0
  39. tests/hello_acl.cpp +62 -0
  40. tests/test_attention_decode.cpp +319 -0
  41. tests/test_attention_layer.cpp +219 -0
  42. tests/test_batch_correctness.cpp +98 -0
  43. tests/test_batch_decode.cpp +85 -0
  44. tests/test_chat_flow.sh +72 -0
  45. tests/test_engine_smoke.cpp +8 -0
  46. tests/test_layer_forward.cpp +192 -0
  47. tests/test_linear_hf.cpp +73 -0
  48. tests/test_model_config.cpp +106 -0
  49. tests/test_moe_layer.cpp +676 -0
  50. tests/test_op_support.cpp +190 -0
.gitignore ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Build artifacts
2
+ /build/
3
+ *.o
4
+ *.obj
5
+ *.a
6
+ *.so
7
+ *.exe
8
+
9
+ # CMake
10
+ CMakeCache.txt
11
+ CMakeFiles/
12
+ cmake_install.cmake
13
+ Makefile
14
+ compile_commands.json
15
+
16
+ # Tokenizer output (regenerated from HF model)
17
+ /tokenizer_data/
18
+ *.bin
19
+ !tests/**/*.bin
20
+
21
+ # Reference data (regenerated by scripts/gen_*.py; too large to commit)
22
+ /tests/attn_data/
23
+ /tests/moe_data/
24
+ /tests/mm_data/
25
+ /tests/rms_norm_data/
26
+ /tests/poc_data/
27
+
28
+ # Runtime state
29
+ /tmp/
30
+ *.log
31
+ /tp_rank_*.log
32
+ hccl_root_info.bin
33
+ /tmp/hccl_root_info.bin
34
+
35
+ # Editor / IDE
36
+ .vscode/
37
+ .idea/
38
+ *.swp
39
+ *.swo
40
+ .DS_Store
41
+
42
+ # Python
43
+ __pycache__/
44
+ *.pyc
45
+ .ipynb_checkpoints/
46
+
47
+ # Benchmark output
48
+ bench_result*.log
CMakeLists.txt ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ cmake_minimum_required(VERSION 3.16)
2
+ project(qwen3-moe-aclnn CXX)
3
+
4
+ set(CMAKE_CXX_STANDARD 17)
5
+ set(CMAKE_CXX_STANDARD_REQUIRED ON)
6
+ set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
7
+
8
+ if(NOT CMAKE_BUILD_TYPE)
9
+ set(CMAKE_BUILD_TYPE Release)
10
+ endif()
11
+
12
+ set(CMAKE_CXX_FLAGS_RELEASE "-O2 -g")
13
+ set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall -Wno-unused-function")
14
+
15
+ # CANN paths
16
+ if(NOT DEFINED CANN_INSTALL_DIR)
17
+ if(DEFINED ENV{ASCEND_TOOLKIT_HOME})
18
+ set(CANN_INSTALL_DIR $ENV{ASCEND_TOOLKIT_HOME})
19
+ else()
20
+ set(CANN_INSTALL_DIR /usr/local/Ascend/ascend-toolkit/latest)
21
+ endif()
22
+ endif()
23
+ message(STATUS "CANN_INSTALL_DIR: ${CANN_INSTALL_DIR}")
24
+
25
+ include_directories(
26
+ ${CANN_INSTALL_DIR}/include
27
+ ${CANN_INSTALL_DIR}/include/aclnn
28
+ ${CMAKE_SOURCE_DIR}/include
29
+ ${CMAKE_SOURCE_DIR}/external
30
+ )
31
+
32
+ link_directories(${CANN_INSTALL_DIR}/lib64)
33
+
34
+ set(CANN_LIBS ascendcl nnopbase opapi opapi_transformer acl_op_compiler hccl)
35
+
36
+ # HCCL headers live under include/ but we need explicit include dir for <hccl/hccl.h>.
37
+ include_directories(${CANN_INSTALL_DIR}/include/hccl)
38
+
39
+ # ---- Library: qwen3-moe-aclnn core ----
40
+ set(LCA_SOURCES
41
+ src/safetensors_loader.cpp
42
+ src/model_config.cpp
43
+ src/tokenizer.cpp
44
+ src/device_weights.cpp
45
+ src/runner.cpp
46
+ )
47
+ add_library(qwen3-moe-aclnn-core STATIC ${LCA_SOURCES})
48
+ target_link_libraries(qwen3-moe-aclnn-core PUBLIC ${CANN_LIBS})
49
+
50
+ # ---- Binaries ----
51
+ add_executable(hello_acl tests/hello_acl.cpp)
52
+ target_link_libraries(hello_acl qwen3-moe-aclnn-core)
53
+
54
+ add_executable(test_safetensors tests/test_safetensors.cpp)
55
+ target_link_libraries(test_safetensors qwen3-moe-aclnn-core)
56
+
57
+ add_executable(test_model_config tests/test_model_config.cpp)
58
+ target_link_libraries(test_model_config qwen3-moe-aclnn-core)
59
+
60
+ add_executable(test_tokenizer tests/test_tokenizer.cpp)
61
+ target_link_libraries(test_tokenizer qwen3-moe-aclnn-core)
62
+
63
+ add_executable(test_rms_norm tests/test_rms_norm.cpp)
64
+ target_link_libraries(test_rms_norm qwen3-moe-aclnn-core)
65
+
66
+ add_executable(test_weight_load tests/test_weight_load.cpp)
67
+ target_link_libraries(test_weight_load qwen3-moe-aclnn-core)
68
+
69
+ add_executable(test_linear_hf tests/test_linear_hf.cpp)
70
+ target_link_libraries(test_linear_hf qwen3-moe-aclnn-core)
71
+
72
+ add_executable(test_rope tests/test_rope.cpp)
73
+ target_link_libraries(test_rope qwen3-moe-aclnn-core)
74
+
75
+ add_executable(test_rope_manual tests/test_rope_manual.cpp)
76
+ target_link_libraries(test_rope_manual qwen3-moe-aclnn-core)
77
+
78
+ add_executable(test_attention_layer tests/test_attention_layer.cpp)
79
+ target_link_libraries(test_attention_layer qwen3-moe-aclnn-core)
80
+
81
+ add_executable(test_moe_layer tests/test_moe_layer.cpp)
82
+ target_link_libraries(test_moe_layer qwen3-moe-aclnn-core)
83
+
84
+ add_executable(test_attention_decode tests/test_attention_decode.cpp)
85
+ target_link_libraries(test_attention_decode qwen3-moe-aclnn-core)
86
+
87
+ add_executable(test_engine_smoke tests/test_engine_smoke.cpp)
88
+ target_link_libraries(test_engine_smoke qwen3-moe-aclnn-core)
89
+
90
+ add_executable(test_layer_forward tests/test_layer_forward.cpp)
91
+ target_link_libraries(test_layer_forward qwen3-moe-aclnn-core)
92
+
93
+ add_executable(test_runner tests/test_runner.cpp)
94
+ target_link_libraries(test_runner qwen3-moe-aclnn-core)
95
+
96
+ # ---- Main CLI ----
97
+ add_executable(qwen3-moe-aclnn src/main_cli.cpp)
98
+ target_link_libraries(qwen3-moe-aclnn qwen3-moe-aclnn-core)
99
+
100
+ add_executable(test_op_support tests/test_op_support.cpp)
101
+ target_link_libraries(test_op_support qwen3-moe-aclnn-core)
102
+
103
+ add_executable(test_rope_fused tests/test_rope_fused.cpp)
104
+ target_link_libraries(test_rope_fused qwen3-moe-aclnn-core)
105
+
106
+ add_executable(test_batch_decode tests/test_batch_decode.cpp)
107
+ target_link_libraries(test_batch_decode qwen3-moe-aclnn-core)
108
+
109
+ add_executable(test_batch_correctness tests/test_batch_correctness.cpp)
110
+ target_link_libraries(test_batch_correctness qwen3-moe-aclnn-core)
LICENSE ADDED
@@ -0,0 +1,176 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for describing the origin of the Work and
141
+ reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Support. While redistributing the Work or
166
+ Derivative Works thereof, You may choose to offer, and charge a
167
+ fee for, acceptance of support, warranty, indemnity, or other
168
+ liability obligations and/or rights consistent with this License.
169
+ However, in accepting such obligations, You may act only on Your
170
+ own behalf and on Your sole responsibility, not on behalf of any
171
+ other Contributor, and only if You agree to indemnify, defend,
172
+ and hold each Contributor harmless for any liability incurred by,
173
+ or claims asserted against, such Contributor by reason of your
174
+ accepting any such warranty or support.
175
+
176
+ END OF TERMS AND CONDITIONS
README.md ADDED
@@ -0,0 +1,338 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # qwen3-moe-aclnn
2
+
3
+ Pure C++ inference of **Qwen3-235B-A22B-Instruct** BF16 on **Ascend 910 × 16 NPU**, built directly on the aclnn EAGER API (no graph compilation, no PyTorch, no ggml).
4
+
5
+ ---
6
+
7
+ ## Performance
8
+
9
+ Measured on Ascend 910 initial-gen × 16 NPU (TP=16) with Qwen3-235B-A22B-Instruct-2507 BF16 weights.
10
+ All numbers are **quality-preserving TG** (output was manually verified); greedy `temperature=0`.
11
+
12
+ | Configuration | TG | Applicable prompts |
13
+ |---|---|---|
14
+ | Untuned baseline | 12 t/s | All |
15
+ | **Default recommended** (no PLD) | **~27 t/s** | **All prompts, stable output** |
16
+ | PLD with degeneration guard | 29-45 t/s | Structured text (essays, long-form answers) |
17
+ | PLD on creative prompts | 25-40 t/s | Stories / varied generation |
18
+ | PLD on factual / code prompts | unstable (21-95 t/s, high variance) | Not recommended |
19
+
20
+ Reference: `cann-recipes-infer` GE graph baseline reports ~54 t/s on the same hardware. **This project does not exceed that baseline** — it trades some peak speed for (a) no graph compilation, (b) no PyTorch dependency, (c) full control over operator scheduling.
21
+
22
+ ### Key optimizations that contributed (in order of magnitude)
23
+
24
+ | Rank | Optimization | Gain | Where |
25
+ |---|---|---|---|
26
+ | 🥇 | HCCL env tuning (`AIV` + `FFTS` + `TASK_QUEUE=2`) | +89% (12→23 t/s) | `scripts/tp_launch.sh` |
27
+ | 🥈 | Fused RoPE via `aclnnApplyRotaryPosEmbV2` | +17% (23→27 t/s) | `include/rope.h` |
28
+ | 🥉 | Prompt Lookup Decoding (PLD) w/ degeneration guard | +10-60% on applicable prompts | `src/main_cli.cpp` |
29
+ | ○ | Device-side topk-w normalize, MoE argsort, cos/sin cache | ~+15% cumulative | `include/engine.h` |
30
+ | ○ | WorkspacePool (thread-local + retain-old) | reduces alloc overhead | `include/workspace_pool.h` |
31
+
32
+ ---
33
+
34
+ ## Architecture
35
+
36
+ **Model**: Qwen3-235B-A22B, 94 layers, 128 experts (top-k=8), GQA (64 Q heads, 4 KV heads), BF16.
37
+
38
+ **Parallelism**: TP=16 via HCCL ring AllReduce. KV heads sharded 1-per-rank (since 4 KV heads < 16 ranks, Q heads 0-3 on each rank share KV head 0).
39
+
40
+ **Execution**: aclnn EAGER mode — every op goes through `aclnn*` single-op API with workspace pool; no graph capture, no GE IR. Async stream execution with `TASK_QUEUE_ENABLE=2` for kernel submission overlap.
41
+
42
+ **Tokenizer**: Uses HuggingFace `transformers` via a Python subprocess for encoding; vocab decode is pure C++ from an exported `vocab.bin`.
43
+
44
+ ### Per-layer forward flow
45
+
46
+ ```
47
+ x_in [S, D=4096]
48
+
49
+ ┌── Attention branch (TP: Q_DIM=512=4h×128, KV_DIM=128=1h×128) ──┐
50
+ │ RmsNorm(input_layernorm)
51
+ │ linear_hf q_proj / k_proj / v_proj → q, k, v
52
+ │ Per-head RmsNorm q_norm, k_norm
53
+ │ Fused RoPE: aclnnApplyRotaryPosEmbV2 (layout=1, "half")
54
+ │ Append K, V to per-layer KV cache
55
+ │ Mask selection:
56
+ │ prefill: 2048×2048 causal + sparse_mode=3
57
+ │ decode S=1: nullptr + sparse_mode=0
58
+ │ batch decode: [1,1,S,past+S] custom bool mask + sparse_mode=0
59
+ │ FIAS (aclnnFusedInferAttentionScore)
60
+ │ o_proj linear_hf → partial per-rank
61
+ │ HCCL AllReduce (ring + AIV + FFTS) → full
62
+ └─────────┘
63
+ ↓ residual add
64
+ ┌── MoE branch ──┐
65
+ │ RmsNorm(post_attention_layernorm)
66
+ │ router linear_hf → logits [S, 128]
67
+ │ moe_gating_topk_softmax → topk_w[S,8], topk_idx[S,8]
68
+ │ Device-side normalize (reduce_sum + adds + cast + div)
69
+ │ moe_init_routing_v3 → expanded_x, expanded_ri, tokens_per_expert
70
+ │ grouped_matmul_v4 gate/up/down (SwiGLU activation)
71
+ │ Device-side argsort × 2 → fwd permutation (avoids host sync)
72
+ │ IndexSelect → packed
73
+ │ Broadcast-mul by topk_w + ReduceSum axis=1
74
+ │ HCCL AllReduce → full
75
+ └─────────┘
76
+ ↓ residual add
77
+ x_out
78
+ ```
79
+
80
+ ---
81
+
82
+ ## Model weights
83
+
84
+ This project targets **Qwen3-235B-A22B-Instruct-2507** (BF16). About **470 GB** of safetensors shards.
85
+
86
+ **Download sources**:
87
+ - HuggingFace: https://huggingface.co/Qwen/Qwen3-235B-A22B-Instruct-2507
88
+ - ModelScope: https://www.modelscope.cn/models/Qwen/Qwen3-235B-A22B-Instruct-2507
89
+
90
+ Download via `huggingface-cli` or `modelscope` CLI:
91
+ ```bash
92
+ # HuggingFace
93
+ huggingface-cli download Qwen/Qwen3-235B-A22B-Instruct-2507 --local-dir /path/to/Qwen3-235B-A22B-Instruct-2507-BF16
94
+
95
+ # ModelScope
96
+ modelscope download --model Qwen/Qwen3-235B-A22B-Instruct-2507 --local_dir /path/to/Qwen3-235B-A22B-Instruct-2507-BF16
97
+ ```
98
+
99
+ **Weights format**: the binary reads HuggingFace `.safetensors` shards (multi-shard mmap), `config.json`, and `tokenizer.json` directly from the model directory. No conversion step is needed — point `--model-dir` at the downloaded directory.
100
+
101
+ **Expected directory contents**:
102
+ ```
103
+ Qwen3-235B-A22B-Instruct-2507-BF16/
104
+ ├── config.json
105
+ ├── tokenizer.json
106
+ ├── tokenizer_config.json
107
+ ├── model-00001-of-000XX.safetensors
108
+ ├── ...
109
+ └── model.safetensors.index.json
110
+ ```
111
+
112
+ ---
113
+
114
+ ## Build
115
+
116
+ ```bash
117
+ source /usr/local/Ascend/ascend-toolkit/set_env.sh
118
+ cmake -B build
119
+ cmake --build build -j8 --target qwen3-moe-aclnn
120
+ ```
121
+
122
+ **Requires**:
123
+ - CANN 8.5.1 or compatible
124
+ - Python 3 + `transformers` + `torch_npu` (for tokenizer subprocess and reference-data generation only)
125
+ - C++17 compiler
126
+ - Ascend 910 × 16 NPU
127
+ - nlohmann/json (bundled as `external/json.hpp`)
128
+
129
+ **Python environment setup** — the tokenizer calls a Python subprocess. Override the activation command via `QWEN3_PYENV_INIT` if your conda / venv layout differs from the default:
130
+ ```bash
131
+ export QWEN3_PYENV_INIT="source /opt/my_conda/etc/profile.d/conda.sh && conda activate my_env && "
132
+ ```
133
+ If unset, the default tries `${HOME}/miniconda3` with env `qwen3` and auto-sources the Ascend toolkit.
134
+
135
+ ---
136
+
137
+ ## Quick-start inference
138
+
139
+ ```bash
140
+ # 1. Export tokenizer vocab to binary (one-time setup)
141
+ python3 scripts/export_vocab.py /path/to/Qwen3-235B-A22B-Instruct-2507-BF16
142
+
143
+ # 2. Run inference (TP=16)
144
+ ./scripts/tp_launch.sh 16 ./build/qwen3-moe-aclnn \
145
+ --model-dir /path/to/Qwen3-235B-A22B-Instruct-2507-BF16 \
146
+ --prompt "The capital of France is" \
147
+ --n-predict 100 \
148
+ --temperature 0 \
149
+ --vocab tokenizer_data/vocab.bin
150
+ ```
151
+
152
+ Expected: ~27 t/s, coherent output.
153
+
154
+ ### Recommended flags by use case
155
+
156
+ **Universal default (stable, any prompt)** — no PLD:
157
+ ```bash
158
+ ./scripts/tp_launch.sh 16 ./build/qwen3-moe-aclnn --model-dir ... --temperature 0 --no-stream
159
+ ```
160
+
161
+ **Structured / long-form (essays, explanations)** — PLD with guard gives +60-90%:
162
+ ```bash
163
+ ./scripts/tp_launch.sh 16 ./build/qwen3-moe-aclnn --model-dir ... --pld --temperature 0 --no-stream
164
+ ```
165
+
166
+ **Interactive REPL (multi-turn chat)**:
167
+ ```bash
168
+ ./scripts/tp_launch.sh 16 ./build/qwen3-moe-aclnn --model-dir ... \
169
+ --interactive --chat --temperature 0.7 --top-p 0.8
170
+ ```
171
+
172
+ ---
173
+
174
+ ## PLD degeneration guard
175
+
176
+ Prompt Lookup Decoding speeds up generation by having the model verify a batch of "draft" tokens in a single forward pass. The drafts are copied from the generation history via n-gram match.
177
+
178
+ **Known failure mode**: on prompts the model tends to repeat on (factual Q&A, code generation), the n-gram match feeds the model's own repetition back as drafts, creating a positive feedback loop that accelerates degenerate output. Early versions of this project reported misleading peak TG numbers driven by this loop.
179
+
180
+ **This project's guard** blocks suspect drafts with two heuristics:
181
+
182
+ 1. **low-distinct**: draft's distinct-token count < threshold → reject
183
+ 2. **tail-echo**: all of last N hist tokens equal draft[0] → reject
184
+
185
+ Rejected drafts fall back to single-token decode. A `[warn]` line is emitted once if the generated tail shows 8 consecutive identical tokens.
186
+
187
+ Flags:
188
+ ```
189
+ --pld enable PLD (opt-in)
190
+ --pld-k N draft window size (default: 10)
191
+ --pld-ngram N n-gram match size (default: 1, with multi-level fallback)
192
+ --pld-min-hist N skip PLD until history >= N tokens (default: 20)
193
+ --pld-no-guard disable the degeneration guard (dangerous: can produce dead loops)
194
+ --pld-guard-distinct N minimum distinct tokens in draft (default: 3)
195
+ --pld-guard-tail N tail-echo window (default: 6)
196
+ --pld-loop-warn N emit warning on N consecutive identical tokens (default: 8)
197
+ ```
198
+
199
+ **Honest benchmarking**: use `scripts/bench_pld_safe.sh`, which classifies each run's output as OK / LOOP_N / LOW_DIVERSITY and separates TG statistics for OK-only vs degraded runs.
200
+
201
+ ---
202
+
203
+ ## Correctness verification
204
+
205
+ 15+ unit / integration tests checked against Python (HuggingFace Transformers) reference:
206
+
207
+ ```bash
208
+ ./build/test_attention_layer # rel=4.9e-4 vs Python prefill
209
+ ./build/test_attention_decode # rel=0 (bit-exact)
210
+ ./build/test_moe_layer # rel=3.6e-3
211
+ ./build/test_layer_forward # full single layer
212
+ ./build/test_runner # multi-layer runner
213
+ ./build/test_rope_fused # aclnnApplyRotaryPosEmbV2 vs manual HF rotate_half
214
+ ./build/test_batch_decode # S=1..8 timing
215
+ ./build/test_batch_correctness # argmax consistency
216
+ ./build/test_op_support # 910-specific op availability
217
+ # Integration smoke:
218
+ ./tests/test_chat_flow.sh # 7/7 PASS
219
+ ```
220
+
221
+ Tests expect reference data under `tests/<name>_data/` generated by `scripts/gen_*_reference.py`. See each script's docstring.
222
+
223
+ ---
224
+
225
+ ## Environment tuning (auto-applied by `tp_launch.sh`)
226
+
227
+ ```bash
228
+ HCCL_WHITELIST_DISABLE=1
229
+ HCCL_ALGO=level0:ring # ring, not fullmesh (fullmesh causes garbled output)
230
+ HCCL_BUFFSIZE=200 # sweet spot; 100 and 400 both slower
231
+ HCCL_OP_EXPANSION_MODE=AIV # key: AI Vector cores participate in reduce scheduling
232
+ HCCL_OP_BASE_FFTS_MODE_ENABLE=1 # key: Fast Frequently-used Transfer Scheduling
233
+ TASK_QUEUE_ENABLE=2 # key: aggressive async task submission
234
+ ```
235
+
236
+ Removing any of the three "key" env vars drops TG by 20-40%.
237
+
238
+ ---
239
+
240
+ ## Directory layout
241
+
242
+ ```
243
+ include/
244
+ ├── acl_common.h RAII wrappers, DeviceBuffer, make_contig_tensor
245
+ ├── aclnn_ops.h single-op wrappers + WorkspacePool integration
246
+ ├── acl_runtime.h AclRuntime (device + stream management)
247
+ ├── device_weights.h safetensors → device loading + TP sharding
248
+ ├── engine.h attention_forward + moe_forward + RopeCache
249
+ ├── hccl_comm.h HCCL init + allreduce + broadcast
250
+ ├── model_config.h Qwen3 hyperparameters + compute_derived
251
+ ├── rope.h apply_rope_fused (aclnnApplyRotaryPosEmbV2 wrapper)
252
+ ├── runner.h Runner class (prefill/decode/decode_batch/rewind/profile)
253
+ ├── safetensors_loader.h multi-shard safetensors mmap parser
254
+ ├── tokenizer.h vocab decode + Python subprocess encode
255
+ └── workspace_pool.h thread-local aclnn workspace pool (retain-old)
256
+
257
+ src/
258
+ ├── device_weights.cpp load_attention (GQA fix), load_moe (permute sync fix)
259
+ ├── main_cli.cpp CLI entry + PLD main loop + degeneration guard + multi-turn
260
+ ├── model_config.cpp compute_derived (GQA KV sharding)
261
+ ├── runner.cpp Runner (build_batch_decode_mask_ etc.)
262
+ ├── safetensors_loader.cpp
263
+ └── tokenizer.cpp
264
+
265
+ scripts/
266
+ ├── tp_launch.sh production launcher (auto-applies HCCL env)
267
+ ├── bench_tg.sh stable N-run TG measurement
268
+ ├── bench_pld_safe.sh PLD benchmark with output-correctness classifier
269
+ ├── bench_hccl[_adv].sh HCCL parameter sweep
270
+ ├── bench_pld[_k].sh PLD K × ngram sweep (legacy, prefer bench_pld_safe.sh)
271
+ ├── export_vocab.py vocab.bin exporter from HF tokenizer
272
+ └── gen_*_reference.py per-op Python reference data generators
273
+
274
+ tests/
275
+ ├── test_attention_* attention correctness (prefill / decode)
276
+ ├── test_moe_layer MoE correctness
277
+ ├── test_layer_forward full single layer
278
+ ├── test_runner multi-layer Runner
279
+ ├── test_rope_fused fused RoPE vs manual HF
280
+ ├── test_batch_* batch decode timing + correctness
281
+ ├── test_op_support 910-specific op availability probe
282
+ └── test_chat_flow.sh end-to-end integration smoke
283
+ ```
284
+
285
+ ---
286
+
287
+ ## CLI reference
288
+
289
+ ```
290
+ --model-dir <path> (required) HF safetensors directory
291
+ --prompt "<text>" prompt text
292
+ --prompt-file FILE read prompt from file (avoids shell-escape issues)
293
+ --n-predict N maximum tokens to generate
294
+ --tp-size N tensor parallelism (or set TP_SIZE env)
295
+ --max-seq N KV cache + context cap (default: 512)
296
+ --temperature F 0 = greedy; typical 0.7
297
+ --top-k N 0 = disabled
298
+ --top-p F 1.0 = disabled
299
+ --seed N 0 = time-based
300
+ --chat apply Qwen3 chat template
301
+ --system "<text>" system role text (with --chat)
302
+ --interactive, -i REPL mode (multi-turn memory with --chat)
303
+ --reset force stateless REPL (reset KV between turns)
304
+ --no-stream batch-print final text instead of per-token streaming
305
+ --vocab <path> vocab.bin path (default: tokenizer_data/vocab.bin)
306
+ --pld* see "PLD degeneration guard" section
307
+ ```
308
+
309
+ ---
310
+
311
+ ## Known limitations
312
+
313
+ - **Not yet reaching cann-recipes GE graph 54 t/s baseline** (currently ~27 t/s stable / up to ~45 t/s PLD).
314
+ Closing the gap requires one of: (a) real graph compilation, (b) fused collectives (`MatmulAllReduce`, `GroupedMatmulAllReduce`) which are absent on 910 initial-gen, (c) migration to 910B/A2/A3.
315
+ - **Only `tp_size` ∈ {1, 2, 4, 8, 16}** supported. Values that don't evenly divide 64 Q heads will error.
316
+ - **PLD on factual/code prompts is unreliable** — either produces baseline TG (guard rejects most drafts) or enters partial degeneration the classifier may not catch at low-severity. Use `bench_pld_safe.sh` to evaluate honestly.
317
+ - **Tokenizer requires Python subprocess** — adds ~1s startup for first encode. Override via `QWEN3_PYENV_INIT` env if default conda path doesn't match.
318
+ - **NPU performance has high run-to-run variance** (up to 4× in some configurations) due to BF16 + MoE intrinsic non-determinism and shared hardware resources. Report medians over ≥5 runs.
319
+
320
+ ---
321
+
322
+ ## Future directions (prioritized)
323
+
324
+ 1. **Draft Model Speculative Decoding** with Qwen3-0.6B — more stable accept rate than n-gram PLD, expected +60-100% TG across prompt types (1-2 week implementation).
325
+ 2. **HCCL AllReduce / compute overlap** — ~+10-15% in theory, limited by EAGER path serial dependencies.
326
+ 3. **KV cache INT8 quantization** — reduces memory-bandwidth pressure, ~+15-25% on long contexts (pending 910-initial-gen op support verification).
327
+ 4. **W8 weight quantization** — ~+10-20% if aclnn quantization kernels exist on 910 initial-gen.
328
+
329
+ Not recommended:
330
+ - `aclmdlRI` stream-capture-style graph recording (POC proved 1.13× ceiling, not worth the engineering cost).
331
+ - Custom AscendC fused ops (high maintenance cost unless dedicated kernel engineer).
332
+ - torchair / torch.compile migration (breaks pure-C++ design).
333
+
334
+ ---
335
+
336
+ ## License
337
+
338
+ Apache License 2.0 — see `LICENSE`.
external/json.hpp ADDED
The diff for this file is too large to render. See raw diff
 
include/acl_common.h ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ #include <acl/acl.h>
3
+ #include <aclnn/acl_meta.h>
4
+ #include <cstdio>
5
+ #include <cstdlib>
6
+ #include <memory>
7
+ #include <string>
8
+ #include <vector>
9
+
10
+ #define ACL_CHECK(x) do { \
11
+ aclError __e = (x); \
12
+ if (__e != ACL_ERROR_NONE) { \
13
+ fprintf(stderr, "ACL error %d at %s:%d : %s\n", __e, __FILE__, __LINE__, #x); \
14
+ std::abort(); \
15
+ } \
16
+ } while(0)
17
+
18
+ #define ACLNN_CHECK(x) do { \
19
+ aclnnStatus __e = (x); \
20
+ if (__e != 0) { \
21
+ const char* __msg = aclGetRecentErrMsg(); \
22
+ fprintf(stderr, "aclnn error %d at %s:%d : %s\n msg: %s\n", (int)__e, __FILE__, __LINE__, #x, __msg ? __msg : "(null)"); \
23
+ std::abort(); \
24
+ } \
25
+ } while(0)
26
+
27
+ // RAII wrapper for aclTensor: call aclDestroyTensor on dtor
28
+ struct AclTensorDel { void operator()(aclTensor* t) const { if (t) aclDestroyTensor(t); } };
29
+ using AclTensorPtr = std::unique_ptr<aclTensor, AclTensorDel>;
30
+
31
+ struct AclTensorListDel { void operator()(aclTensorList* t) const { if (t) aclDestroyTensorList(t); } };
32
+ using AclTensorListPtr = std::unique_ptr<aclTensorList, AclTensorListDel>;
33
+
34
+ struct AclIntArrayDel { void operator()(aclIntArray* a) const { if (a) aclDestroyIntArray(a); } };
35
+ using AclIntArrayPtr = std::unique_ptr<aclIntArray, AclIntArrayDel>;
36
+
37
+ // Create ACL tensor with explicit row-major shape (outermost leftmost) and element strides.
38
+ // NOTE: stride is in ELEMENTS, not bytes.
39
+ inline AclTensorPtr make_acl_tensor(void* data, aclDataType dt,
40
+ const std::vector<int64_t>& shape,
41
+ const std::vector<int64_t>& stride_elems,
42
+ aclFormat fmt = ACL_FORMAT_ND) {
43
+ int64_t n = (int64_t)shape.size();
44
+ int64_t storage_len = 1;
45
+ for (int i = 0; i < n; i++) storage_len += (shape[i] - 1) * stride_elems[i];
46
+ aclTensor* t = aclCreateTensor(
47
+ shape.data(), (uint64_t)n, dt,
48
+ stride_elems.data(), 0, fmt,
49
+ &storage_len, 1, data);
50
+ return AclTensorPtr(t);
51
+ }
52
+
53
+ // Default contiguous strides for row-major tensor: stride[i] = product of shape[i+1..n-1]
54
+ inline std::vector<int64_t> contiguous_strides(const std::vector<int64_t>& shape) {
55
+ int n = (int)shape.size();
56
+ std::vector<int64_t> s(n);
57
+ int64_t acc = 1;
58
+ for (int i = n - 1; i >= 0; --i) {
59
+ s[i] = acc;
60
+ acc *= shape[i];
61
+ }
62
+ return s;
63
+ }
64
+
65
+ inline AclTensorPtr make_contig_tensor(void* data, aclDataType dt,
66
+ const std::vector<int64_t>& shape,
67
+ aclFormat fmt = ACL_FORMAT_ND) {
68
+ return make_acl_tensor(data, dt, shape, contiguous_strides(shape), fmt);
69
+ }
70
+
71
+ inline size_t dtype_size(aclDataType dt) {
72
+ switch (dt) {
73
+ case ACL_FLOAT: return 4;
74
+ case ACL_FLOAT16: return 2;
75
+ case ACL_BF16: return 2;
76
+ case ACL_INT8: return 1;
77
+ case ACL_INT32: return 4;
78
+ case ACL_INT64: return 8;
79
+ default: return 0;
80
+ }
81
+ }
82
+
83
+ // Device buffer RAII: allocates via aclrtMalloc, frees in dtor
84
+ struct DeviceBuffer {
85
+ void* ptr = nullptr;
86
+ size_t size = 0;
87
+
88
+ DeviceBuffer() = default;
89
+ explicit DeviceBuffer(size_t bytes) { alloc(bytes); }
90
+ ~DeviceBuffer() { if (ptr) aclrtFree(ptr); }
91
+ DeviceBuffer(const DeviceBuffer&) = delete;
92
+ DeviceBuffer& operator=(const DeviceBuffer&) = delete;
93
+ DeviceBuffer(DeviceBuffer&& o) noexcept : ptr(o.ptr), size(o.size) { o.ptr = nullptr; o.size = 0; }
94
+ DeviceBuffer& operator=(DeviceBuffer&& o) noexcept {
95
+ if (this != &o) { if (ptr) aclrtFree(ptr); ptr = o.ptr; size = o.size; o.ptr = nullptr; o.size = 0; }
96
+ return *this;
97
+ }
98
+
99
+ void alloc(size_t bytes) {
100
+ if (ptr) aclrtFree(ptr);
101
+ ACL_CHECK(aclrtMalloc(&ptr, bytes, ACL_MEM_MALLOC_HUGE_FIRST));
102
+ size = bytes;
103
+ }
104
+ void* get() { return ptr; }
105
+ const void* get() const { return ptr; }
106
+ };
include/acl_runtime.h ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // acl_runtime.h — per-rank ACL runtime init/teardown.
2
+ #pragma once
3
+ #include "acl_common.h"
4
+ #include <cstdio>
5
+
6
+ class AclRuntime {
7
+ public:
8
+ AclRuntime() = default;
9
+ ~AclRuntime() { shutdown(); }
10
+
11
+ bool init(int device_id) {
12
+ if (initialized_) return true;
13
+ device_id_ = device_id;
14
+ ACL_CHECK(aclInit(nullptr));
15
+ ACL_CHECK(aclrtSetDevice(device_id));
16
+ ACL_CHECK(aclrtCreateContext(&ctx_, device_id));
17
+ ACL_CHECK(aclrtCreateStream(&stream_));
18
+ initialized_ = true;
19
+ return true;
20
+ }
21
+
22
+ void shutdown() {
23
+ if (!initialized_) return;
24
+ if (stream_) { aclrtDestroyStream(stream_); stream_ = nullptr; }
25
+ if (ctx_) { aclrtDestroyContext(ctx_); ctx_ = nullptr; }
26
+ aclrtResetDevice(device_id_);
27
+ aclFinalize();
28
+ initialized_ = false;
29
+ }
30
+
31
+ void sync() { if (stream_) ACL_CHECK(aclrtSynchronizeStream(stream_)); }
32
+
33
+ aclrtStream stream() const { return stream_; }
34
+ int device_id() const { return device_id_; }
35
+
36
+ private:
37
+ bool initialized_ = false;
38
+ int device_id_ = 0;
39
+ aclrtContext ctx_ = nullptr;
40
+ aclrtStream stream_ = nullptr;
41
+ };
include/aclnn_ops.h ADDED
@@ -0,0 +1,345 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // aclnn_ops.h — thin wrappers around common aclnn operators used in forward pass.
2
+ // Each wrapper does GetWorkspaceSize + op call on the provided stream.
3
+ //
4
+ // All tensors are passed as raw aclTensor* (caller owns them).
5
+ // Workspace allocation uses DeviceBuffer (RAII).
6
+ #pragma once
7
+ #include "acl_common.h"
8
+ #include "workspace_pool.h"
9
+
10
+ // Thread-local shared workspace pool for all aclnn wrappers below. Single-threaded stream
11
+ // means we can safely reuse one buffer across serial op calls. Set via `GGML_CANN_WP=0` is
12
+ // not supported here — if truly needed, we'd wire a flag.
13
+ inline WorkspacePool& _lca_pool() {
14
+ thread_local WorkspacePool pool;
15
+ return pool;
16
+ }
17
+
18
+ #include <aclnnop/aclnn_add.h>
19
+ #include <aclnnop/aclnn_addcmul.h>
20
+ #include <aclnnop/aclnn_grouped_matmul_v4.h>
21
+ #include <aclnnop/aclnn_moe_finalize_routing.h>
22
+ #include <aclnnop/aclnn_moe_finalize_routing_v2.h>
23
+ #include <aclnnop/aclnn_moe_gating_top_k_softmax.h>
24
+ #include <aclnnop/aclnn_moe_init_routing_v3.h>
25
+ #include <aclnnop/aclnn_cast.h>
26
+ #include <aclnnop/aclnn_copy.h>
27
+ #include <aclnnop/aclnn_div.h>
28
+ #include <aclnnop/aclnn_fused_infer_attention_score.h>
29
+ #include <aclnnop/aclnn_index_select.h>
30
+ #include <aclnnop/aclnn_matmul.h>
31
+ #include <aclnnop/aclnn_mul.h>
32
+ #include <aclnnop/aclnn_neg.h>
33
+ #include <aclnnop/aclnn_reduce_sum.h>
34
+ #include <aclnnop/aclnn_silu.h>
35
+
36
+ // ---- RmsNorm ----
37
+ // Signature (based on ggml-cann usage): aclnnRmsNorm(x, gamma, eps, y, rstd)
38
+ // where rstd (rsqrt of mean-square) is an extra output we usually discard.
39
+
40
+ // Forward declare header; include happens in impl file to keep this header light.
41
+ extern "C" {
42
+ #include <aclnnop/aclnn_rms_norm.h>
43
+ }
44
+
45
+ inline void rms_norm(aclrtStream stream,
46
+ aclTensor* x, // [N, D] BF16/FP16
47
+ aclTensor* gamma, // [D] same dtype as x
48
+ double eps,
49
+ aclTensor* y, // [N, D]
50
+ aclTensor* rstd // [N] fp32 (required output)
51
+ ) {
52
+ uint64_t ws = 0;
53
+ aclOpExecutor* exec = nullptr;
54
+ ACLNN_CHECK(aclnnRmsNormGetWorkspaceSize(x, gamma, eps, y, rstd, &ws, &exec));
55
+ void* wp = (ws > 0) ? _lca_pool().alloc(ws) : nullptr;
56
+ ACLNN_CHECK(aclnnRmsNorm(wp, ws, exec, stream));
57
+ }
58
+
59
+ // ---- Silu ----
60
+ inline void silu(aclrtStream stream, aclTensor* x, aclTensor* y) {
61
+ uint64_t ws = 0;
62
+ aclOpExecutor* exec = nullptr;
63
+ ACLNN_CHECK(aclnnSiluGetWorkspaceSize(x, y, &ws, &exec));
64
+ void* wp = (ws > 0) ? _lca_pool().alloc(ws) : nullptr;
65
+ ACLNN_CHECK(aclnnSilu(wp, ws, exec, stream));
66
+ }
67
+
68
+ // ---- Mul (element-wise) ----
69
+ inline void mul(aclrtStream stream, aclTensor* a, aclTensor* b, aclTensor* out) {
70
+ uint64_t ws = 0;
71
+ aclOpExecutor* exec = nullptr;
72
+ ACLNN_CHECK(aclnnMulGetWorkspaceSize(a, b, out, &ws, &exec));
73
+ void* wp = (ws > 0) ? _lca_pool().alloc(ws) : nullptr;
74
+ ACLNN_CHECK(aclnnMul(wp, ws, exec, stream));
75
+ }
76
+
77
+ // ---- Cast ----
78
+ inline void cast(aclrtStream stream, aclTensor* x, aclDataType dst_dtype, aclTensor* y) {
79
+ uint64_t ws = 0;
80
+ aclOpExecutor* exec = nullptr;
81
+ ACLNN_CHECK(aclnnCastGetWorkspaceSize(x, dst_dtype, y, &ws, &exec));
82
+ void* wp = (ws > 0) ? _lca_pool().alloc(ws) : nullptr;
83
+ ACLNN_CHECK(aclnnCast(wp, ws, exec, stream));
84
+ }
85
+
86
+ // ---- InplaceCopy: copy src (possibly non-contiguous via strides) into contiguous dst ----
87
+ inline void inplace_copy(aclrtStream stream, aclTensor* dst, aclTensor* src) {
88
+ uint64_t ws = 0;
89
+ aclOpExecutor* exec = nullptr;
90
+ ACLNN_CHECK(aclnnInplaceCopyGetWorkspaceSize(dst, src, &ws, &exec));
91
+ void* wp = (ws > 0) ? _lca_pool().alloc(ws) : nullptr;
92
+ ACLNN_CHECK(aclnnInplaceCopy(wp, ws, exec, stream));
93
+ }
94
+
95
+ // ---- Matmul: out = a @ b ----
96
+ // cube_math_type:
97
+ // 0 = KEEP_DTYPE, 1 = ALLOW_FP32_DOWN_PRECISION, 2 = USE_FP16, 3 = USE_HF32
98
+ inline void matmul(aclrtStream stream,
99
+ aclTensor* a, aclTensor* b, aclTensor* out,
100
+ int8_t cube_math_type = 1) {
101
+ uint64_t ws = 0;
102
+ aclOpExecutor* exec = nullptr;
103
+ ACLNN_CHECK(aclnnMatmulGetWorkspaceSize(a, b, out, cube_math_type, &ws, &exec));
104
+ void* wp = (ws > 0) ? _lca_pool().alloc(ws) : nullptr;
105
+ ACLNN_CHECK(aclnnMatmul(wp, ws, exec, stream));
106
+ }
107
+
108
+ // ---- Neg ----
109
+ inline void neg(aclrtStream stream, aclTensor* x, aclTensor* y) {
110
+ uint64_t ws = 0;
111
+ aclOpExecutor* exec = nullptr;
112
+ ACLNN_CHECK(aclnnNegGetWorkspaceSize(x, y, &ws, &exec));
113
+ void* wp = (ws > 0) ? _lca_pool().alloc(ws) : nullptr;
114
+ ACLNN_CHECK(aclnnNeg(wp, ws, exec, stream));
115
+ }
116
+
117
+ // ---- Addcmul: self = self + value * (tensor1 * tensor2) ----
118
+ inline void addcmul(aclrtStream stream, aclTensor* self_io, aclTensor* t1, aclTensor* t2, float value) {
119
+ aclScalar* v = aclCreateScalar(&value, ACL_FLOAT);
120
+ uint64_t ws = 0;
121
+ aclOpExecutor* exec = nullptr;
122
+ ACLNN_CHECK(aclnnAddcmulGetWorkspaceSize(self_io, t1, t2, v, self_io, &ws, &exec));
123
+ void* wp = (ws > 0) ? _lca_pool().alloc(ws) : nullptr;
124
+ ACLNN_CHECK(aclnnAddcmul(wp, ws, exec, stream));
125
+ aclDestroyScalar(v);
126
+ }
127
+
128
+ // ---- MoE Gating TopK Softmax ----
129
+ // x [N, E] → y [N, K] (top-K softmax probs), expert_idx [N, K] int32, row_idx [N, K] int32
130
+ inline void moe_gating_topk_softmax(aclrtStream stream,
131
+ aclTensor* x, int64_t k,
132
+ aclTensor* y_out, aclTensor* idx_out, aclTensor* row_idx_out) {
133
+ uint64_t ws = 0;
134
+ aclOpExecutor* exec = nullptr;
135
+ ACLNN_CHECK(aclnnMoeGatingTopKSoftmaxGetWorkspaceSize(x, nullptr, k, y_out, idx_out, row_idx_out, &ws, &exec));
136
+ void* wp = (ws > 0) ? _lca_pool().alloc(ws) : nullptr;
137
+ ACLNN_CHECK(aclnnMoeGatingTopKSoftmax(wp, ws, exec, stream));
138
+ }
139
+
140
+ // ---- MoE Init Routing V3 ----
141
+ // x [N, D], expert_idx [N, K] int32 → expanded_x [N*K, D], expanded_row_idx [N*K] int32,
142
+ // tokens_per_expert [E] int64
143
+ inline void moe_init_routing_v3(aclrtStream stream,
144
+ aclTensor* x, aclTensor* expert_idx,
145
+ int64_t n_experts, int64_t active_num,
146
+ aclTensor* expanded_x, aclTensor* expanded_row_idx,
147
+ aclTensor* tokens_per_expert)
148
+ {
149
+ int64_t range[2] = {0, n_experts};
150
+ aclIntArray* r = aclCreateIntArray(range, 2);
151
+ // scale_out_optional we dummy since quant_mode=-1 (no quant) still requires pass a placeholder?
152
+ // Per our POC test earlier: pass a real tensor for scale_out works.
153
+ // For simplicity here, we'll allocate a dummy [active_num] float tensor.
154
+ DeviceBuffer dummy(active_num * 4);
155
+ auto t_dummy = make_contig_tensor(dummy.get(), ACL_FLOAT, {active_num});
156
+
157
+ uint64_t ws = 0; aclOpExecutor* exec = nullptr;
158
+ // rowIdxType=1: expanded_row_idx[i] = sorted_position p for i-th original (n,k) flat index.
159
+ // This lets us use expanded_row_idx directly as the gather index (forward permutation).
160
+ ACLNN_CHECK(aclnnMoeInitRoutingV3GetWorkspaceSize(
161
+ x, expert_idx, nullptr, nullptr,
162
+ active_num, 0, n_experts, 0, 1, true, -1,
163
+ r, 1,
164
+ expanded_x, expanded_row_idx, tokens_per_expert, t_dummy.get(),
165
+ &ws, &exec));
166
+ void* wp = (ws > 0) ? _lca_pool().alloc(ws) : nullptr;
167
+ ACLNN_CHECK(aclnnMoeInitRoutingV3(wp, ws, exec, stream));
168
+ aclDestroyIntArray(r);
169
+ }
170
+
171
+ // ---- GroupedMatmulV4 (single-in single-out, M-axis split) ----
172
+ // x [T, K_in], w [E, K_in, N_out] contiguous row-major, group_list [E] int64 → y [T, N_out]
173
+ // group_list_type: 0=cumsum, 1=counts (V4 doc)
174
+ inline void grouped_matmul_v4(aclrtStream stream,
175
+ aclTensor* x, aclTensor* w, aclTensor* group_list, aclTensor* y,
176
+ int64_t group_list_type = 1)
177
+ {
178
+ aclTensor* xa[] = {x}; aclTensorList* x_list = aclCreateTensorList(xa, 1);
179
+ aclTensor* wa[] = {w}; aclTensorList* w_list = aclCreateTensorList(wa, 1);
180
+ aclTensor* ya[] = {y}; aclTensorList* y_list = aclCreateTensorList(ya, 1);
181
+
182
+ uint64_t ws = 0; aclOpExecutor* exec = nullptr;
183
+ ACLNN_CHECK(aclnnGroupedMatmulV4GetWorkspaceSize(
184
+ x_list, w_list,
185
+ nullptr, nullptr, nullptr, nullptr, nullptr, nullptr,
186
+ group_list,
187
+ nullptr, nullptr, nullptr,
188
+ 3, 0, group_list_type, 0,
189
+ y_list, nullptr, nullptr,
190
+ &ws, &exec));
191
+ void* wp = (ws > 0) ? _lca_pool().alloc(ws) : nullptr;
192
+ ACLNN_CHECK(aclnnGroupedMatmulV4(wp, ws, exec, stream));
193
+ // NOTE: TensorList takes ownership of the raw tensors. Destroying the list frees them,
194
+ // which would cause double-free in the caller's AclTensorPtr. Leak the list (small cost).
195
+ // A cleaner API would accept (ptr, shape, dtype) triples and build tensors internally.
196
+ // TODO(M6): refactor for long-running use.
197
+ }
198
+
199
+ // ---- MoE Finalize Routing V2: out = x1 + weighted_sum of top-K outputs ----
200
+ // V2 has all inputs optional except expandedX/expandedRowIdx/out; pass nullptr for x1 to
201
+ // skip the residual add, or pass the residual to fuse it into this op.
202
+ inline void moe_finalize_routing(aclrtStream stream,
203
+ aclTensor* expanded_x,
204
+ aclTensor* x1_skip, // [N, D] added to output (nullable)
205
+ aclTensor* scales, // weights [N, K]
206
+ aclTensor* expanded_row_idx,
207
+ aclTensor* expert_idx, // [N, K] topk expert indices (nullable)
208
+ aclTensor* out)
209
+ {
210
+ uint64_t ws = 0; aclOpExecutor* exec = nullptr;
211
+ ACLNN_CHECK(aclnnMoeFinalizeRoutingV2GetWorkspaceSize(
212
+ expanded_x,
213
+ expanded_row_idx,
214
+ x1_skip, // x1Optional
215
+ nullptr, // x2Optional
216
+ nullptr, // biasOptional
217
+ scales, // scalesOptional
218
+ expert_idx, // expertIdxOptional (needed for correct routing)
219
+ 0, // dropPadMode (0 = dropless, which matches our pipeline)
220
+ out,
221
+ &ws, &exec));
222
+ void* wp = (ws > 0) ? _lca_pool().alloc(ws) : nullptr;
223
+ ACLNN_CHECK(aclnnMoeFinalizeRoutingV2(wp, ws, exec, stream));
224
+ }
225
+
226
+ // ---- Div: self / other (broadcast supported) ----
227
+ inline void div_tensor(aclrtStream stream, aclTensor* self, aclTensor* other, aclTensor* out) {
228
+ uint64_t ws = 0; aclOpExecutor* exec = nullptr;
229
+ ACLNN_CHECK(aclnnDivGetWorkspaceSize(self, other, out, &ws, &exec));
230
+ void* wp = (ws > 0) ? _lca_pool().alloc(ws) : nullptr;
231
+ ACLNN_CHECK(aclnnDiv(wp, ws, exec, stream));
232
+ }
233
+
234
+ // ---- In-place scalar add: self += scalar ----
235
+ #include <aclnnop/aclnn_add.h>
236
+ #include <aclnnop/aclnn_argsort.h>
237
+
238
+ // ---- Argsort: indices that would sort self along dim (returns INT64) ----
239
+ inline void argsort(aclrtStream stream, aclTensor* self, int64_t dim, bool descending,
240
+ aclTensor* indices_out) {
241
+ uint64_t ws = 0; aclOpExecutor* exec = nullptr;
242
+ ACLNN_CHECK(aclnnArgsortGetWorkspaceSize(self, dim, descending, indices_out, &ws, &exec));
243
+ void* wp = (ws > 0) ? _lca_pool().alloc(ws) : nullptr;
244
+ ACLNN_CHECK(aclnnArgsort(wp, ws, exec, stream));
245
+ }
246
+
247
+ inline void inplace_adds(aclrtStream stream, aclTensor* self, double value) {
248
+ float v = (float)value;
249
+ aclScalar* s = aclCreateScalar(&v, ACL_FLOAT);
250
+ float alpha_v = 1.0f;
251
+ aclScalar* al = aclCreateScalar(&alpha_v, ACL_FLOAT);
252
+ uint64_t ws = 0; aclOpExecutor* exec = nullptr;
253
+ ACLNN_CHECK(aclnnInplaceAddsGetWorkspaceSize(self, s, al, &ws, &exec));
254
+ void* wp = (ws > 0) ? _lca_pool().alloc(ws) : nullptr;
255
+ ACLNN_CHECK(aclnnInplaceAdds(wp, ws, exec, stream));
256
+ aclDestroyScalar(s);
257
+ aclDestroyScalar(al);
258
+ }
259
+
260
+ // ---- ReduceSum over specified dims ----
261
+ inline void reduce_sum(aclrtStream stream, aclTensor* self, const std::vector<int64_t>& dims,
262
+ bool keep_dims, aclDataType out_dtype, aclTensor* out) {
263
+ aclIntArray* d = aclCreateIntArray(dims.data(), dims.size());
264
+ uint64_t ws = 0; aclOpExecutor* exec = nullptr;
265
+ ACLNN_CHECK(aclnnReduceSumGetWorkspaceSize(self, d, keep_dims, out_dtype, out, &ws, &exec));
266
+ void* wp = (ws > 0) ? _lca_pool().alloc(ws) : nullptr;
267
+ ACLNN_CHECK(aclnnReduceSum(wp, ws, exec, stream));
268
+ aclDestroyIntArray(d);
269
+ }
270
+
271
+ // ---- IndexSelect: out[j] = self[index[j], ...] ----
272
+ inline void index_select(aclrtStream stream, aclTensor* self, int64_t dim, aclTensor* index, aclTensor* out) {
273
+ uint64_t ws = 0;
274
+ aclOpExecutor* exec = nullptr;
275
+ ACLNN_CHECK(aclnnIndexSelectGetWorkspaceSize(self, dim, index, out, &ws, &exec));
276
+ void* wp = (ws > 0) ? _lca_pool().alloc(ws) : nullptr;
277
+ ACLNN_CHECK(aclnnIndexSelect(wp, ws, exec, stream));
278
+ }
279
+
280
+ // ---- FusedInferAttentionScore (simplified wrapper for prefill/decode without quant, BSH layout).
281
+ // Caller owns q/k/v/mask/out; k/v are single-tensor lists.
282
+ inline void fused_infer_attention_score(
283
+ aclrtStream stream,
284
+ aclTensor* q, // [B, S, Hq*Dh] BF16
285
+ aclTensor* k, // [B, S, Hkv*Dh] BF16
286
+ aclTensor* v, // [B, S, Hkv*Dh] BF16
287
+ aclTensor* atten_mask, // [1, 1, M, M] bool, sparse_mode=3 needs M=2048
288
+ std::vector<int64_t> actual_seq_lens,
289
+ std::vector<int64_t> actual_seq_lens_kv,
290
+ int64_t num_heads, int64_t num_kv_heads,
291
+ double scale, int64_t sparse_mode,
292
+ aclTensor* out) // [B, S, Hq*Dh]
293
+ {
294
+ aclTensor* k_arr[] = {k};
295
+ aclTensor* v_arr[] = {v};
296
+ aclTensorList* k_list = aclCreateTensorList(k_arr, 1);
297
+ aclTensorList* v_list = aclCreateTensorList(v_arr, 1);
298
+ aclIntArray* sq = aclCreateIntArray(actual_seq_lens.data(), (uint64_t)actual_seq_lens.size());
299
+ aclIntArray* skv = aclCreateIntArray(actual_seq_lens_kv.data(), (uint64_t)actual_seq_lens_kv.size());
300
+
301
+ uint64_t ws = 0;
302
+ aclOpExecutor* exec = nullptr;
303
+ ACLNN_CHECK(aclnnFusedInferAttentionScoreGetWorkspaceSize(
304
+ q, k_list, v_list,
305
+ nullptr, // pseShift
306
+ atten_mask,
307
+ sq, skv,
308
+ nullptr, nullptr, nullptr, nullptr, nullptr, // dequant/quant scales
309
+ nullptr, nullptr, // antiquant
310
+ nullptr, nullptr, nullptr, // block_table, q_padding, kv_padding
311
+ num_heads,
312
+ scale,
313
+ 2147483647, 2147483647, // pre/next tokens (no limit)
314
+ (char*)"BSH",
315
+ num_kv_heads,
316
+ sparse_mode,
317
+ 0, // inner_precise
318
+ 0, 0, // block_size, antiquant_mode
319
+ false, // softmax_lse_flag
320
+ out, nullptr,
321
+ &ws, &exec));
322
+
323
+ void* wp = (ws > 0) ? _lca_pool().alloc(ws) : nullptr;
324
+ ACLNN_CHECK(aclnnFusedInferAttentionScore(wp, ws, exec, stream));
325
+ // See note on grouped_matmul_v4 — intentionally leak lists to avoid double-free with caller RAII.
326
+ (void)k_list; (void)v_list;
327
+ aclDestroyIntArray(sq);
328
+ aclDestroyIntArray(skv);
329
+ }
330
+
331
+ // ---- "Linear" helper: y = x @ W.T where W is stored as [out_features, in_features] (HF convention).
332
+ // Achieved by viewing W as [in_features, out_features] with stride [1, in_features] (elements).
333
+ // Returns y [N, out_features].
334
+ // Caller allocates y.
335
+ inline void linear_hf(aclrtStream stream,
336
+ aclTensor* x, // [N, in_features]
337
+ void* W_data, aclDataType dtype,
338
+ int64_t out_features, int64_t in_features,
339
+ aclTensor* y_out) // [N, out_features]
340
+ {
341
+ auto W_view = make_acl_tensor(W_data, dtype,
342
+ {in_features, out_features},
343
+ {1, in_features}); // strides: d0=1 elem, d1=in_features elems
344
+ matmul(stream, x, W_view.get(), y_out);
345
+ }
include/device_weights.h ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // device_weights.h — load safetensors weights to device memory with proper TP shard.
2
+ //
3
+ // For M3 (attention only): loads attention + norm weights. MoE expert weights come in M4.
4
+ //
5
+ #pragma once
6
+ #include "acl_common.h"
7
+ #include "model_config.h"
8
+ #include "safetensors_loader.h"
9
+
10
+ #include <string>
11
+ #include <unordered_map>
12
+ #include <vector>
13
+
14
+ // Per-layer MoE weights on device (BF16).
15
+ // After loading: weights are in GMM-ready layout [E, K_in, N_out] row-major contiguous.
16
+ // For gate/up: K_in=D, N_out=I_per_rank
17
+ // For down: K_in=I, N_out=D
18
+ struct LayerMoEWeights {
19
+ DeviceBuffer router; // [E, D] BF16 replicated
20
+ DeviceBuffer gate_exps; // [E, D, I_per_rank] (permuted from HF [E, I, D])
21
+ DeviceBuffer up_exps; // [E, D, I_per_rank]
22
+ DeviceBuffer down_exps; // [E, I_per_rank, D] (permuted from HF [E, D, I])
23
+ };
24
+
25
+ // Per-layer attention weights on device (BF16 unless noted).
26
+ struct LayerAttnWeights {
27
+ DeviceBuffer input_layernorm; // [D] BF16
28
+ DeviceBuffer post_attention_layernorm; // [D] BF16
29
+ // Q/K/V/O projections. HF stores as [out, in] BF16.
30
+ // For M3 we keep HF layout as-is; matmul wrappers handle the transpose via aclnnMm semantics.
31
+ DeviceBuffer q_proj; // [Q_full, D] on rank, but physical stored as [Q_rank, D] (sliced by head)
32
+ DeviceBuffer k_proj; // [KV, D] (replicated if tp_size > num_kv_heads)
33
+ DeviceBuffer v_proj; // [KV, D]
34
+ DeviceBuffer o_proj; // [D, Q_rank] (row-parallel on Q dim)
35
+ DeviceBuffer q_norm; // [head_dim] BF16 (Qwen3 per-head norm)
36
+ DeviceBuffer k_norm; // [head_dim] BF16
37
+ };
38
+
39
+ // Shared model weights (replicated across ranks).
40
+ struct SharedWeights {
41
+ DeviceBuffer embed_tokens; // [vocab, D]
42
+ DeviceBuffer lm_head; // [vocab, D]
43
+ DeviceBuffer final_norm; // [D]
44
+ };
45
+
46
+ class DeviceWeightsLoader {
47
+ public:
48
+ DeviceWeightsLoader(SafetensorsLoader& st, const ModelConfig& cfg)
49
+ : st_(st), cfg_(cfg) {}
50
+
51
+ // Load shared (embed, norm, lm_head). Replicated on every rank.
52
+ bool load_shared(SharedWeights& out);
53
+
54
+ // Load ONE attention layer's weights with TP sharding.
55
+ bool load_attention(int layer_idx, LayerAttnWeights& out);
56
+
57
+ // Load ONE MoE layer's weights. Stacks 128 experts and permutes to GMM-ready layout.
58
+ // stream: ACL stream for the permute op (aclnnInplaceCopy).
59
+ bool load_moe(int layer_idx, aclrtStream stream, LayerMoEWeights& out);
60
+
61
+ // Expose underlying safetensors for direct access (diagnostic use).
62
+ SafetensorsLoader& st() { return st_; }
63
+
64
+ private:
65
+ SafetensorsLoader& st_;
66
+ const ModelConfig& cfg_;
67
+
68
+ // Helper: load HF tensor (full shape) into device buffer (simple H2D).
69
+ bool load_tensor_full_(const std::string& name, DeviceBuffer& buf);
70
+
71
+ // Helper: load HF tensor and keep only [row_lo, row_hi) of first dim (TP shard by "out" dim).
72
+ // HF format: tensor has shape [D0, D1, ...] stored row-major. We take rows [lo, hi) to form
73
+ // a sharded tensor of shape [hi-lo, D1, ...].
74
+ bool load_tensor_row_slice_(const std::string& name,
75
+ int64_t row_lo, int64_t row_hi,
76
+ DeviceBuffer& buf);
77
+
78
+ // TP shard by "in" dim (second axis for 2D, etc.) — used for o_proj (row-parallel).
79
+ bool load_tensor_col_slice_(const std::string& name,
80
+ int64_t col_lo, int64_t col_hi,
81
+ DeviceBuffer& buf);
82
+ };
include/engine.h ADDED
@@ -0,0 +1,354 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // engine.h — single-layer forward functions for attention and MoE.
2
+ //
3
+ // Both functions operate on device tensors. The caller owns all buffers (input, output, weights,
4
+ // KV cache slots, scratch). They take RoPE cos/sin tables and act as pure forward kernels.
5
+ //
6
+ // Design goals:
7
+ // - Zero allocations per call (all scratch is passed in)
8
+ // - Same signature works for prefill (S>=1) and decode (S=1); caller picks sparse_mode.
9
+ // - Residual connection is NOT included (caller decides when to add residual).
10
+ #pragma once
11
+ #include "acl_common.h"
12
+ #include "aclnn_ops.h"
13
+ #include "device_weights.h"
14
+ #include "hccl_comm.h"
15
+ #include "model_config.h"
16
+ #include "rope.h"
17
+
18
+ #include <algorithm>
19
+ #include <cmath>
20
+ #include <cstring>
21
+ #include <tuple>
22
+ #include <vector>
23
+
24
+ // Bf16 conversion helpers used by fill_cos_sin.
25
+ static inline uint16_t _engine_f2bf16(float x) {
26
+ uint32_t u; std::memcpy(&u, &x, 4);
27
+ return (uint16_t)((u + 0x7FFF + ((u >> 16) & 1)) >> 16);
28
+ }
29
+
30
+ // Fill cos/sin tables for positions [p0, p0+L) with HF half-half layout. Returns
31
+ // contiguous [L*Dh] BF16 in provided host vectors (caller uploads to device).
32
+ inline void fill_cos_sin_hf(std::vector<uint16_t>& cos_h, std::vector<uint16_t>& sin_h,
33
+ int64_t p0, int64_t L, int64_t Dh, float theta) {
34
+ cos_h.resize(L * Dh);
35
+ sin_h.resize(L * Dh);
36
+ int64_t half = Dh / 2;
37
+ for (int64_t s = 0; s < L; s++) {
38
+ for (int64_t d = 0; d < Dh; d++) {
39
+ int64_t pair = (d < half) ? d : (d - half);
40
+ float theta_pair = 1.0f / std::pow(theta, (2.0f * pair) / Dh);
41
+ float angle = (float)(p0 + s) * theta_pair;
42
+ cos_h[s * Dh + d] = _engine_f2bf16(std::cos(angle));
43
+ sin_h[s * Dh + d] = _engine_f2bf16(std::sin(angle));
44
+ }
45
+ }
46
+ }
47
+
48
+ // Precomputed RoPE cos/sin table: BF16 [max_seq, Dh]. One-time cost per runtime.
49
+ struct RopeCache {
50
+ DeviceBuffer cos; // [max_seq, Dh] BF16
51
+ DeviceBuffer sin; // [max_seq, Dh] BF16
52
+ int64_t max_seq = 0;
53
+ int64_t head_dim = 0;
54
+ float theta = 0.0f;
55
+ };
56
+
57
+ inline bool rope_cache_build(RopeCache& rc, int64_t max_seq, int64_t head_dim, float theta) {
58
+ std::vector<uint16_t> cos_h, sin_h;
59
+ fill_cos_sin_hf(cos_h, sin_h, /*p0=*/0, max_seq, head_dim, theta);
60
+ rc.cos.alloc(max_seq * head_dim * 2);
61
+ rc.sin.alloc(max_seq * head_dim * 2);
62
+ ACL_CHECK(aclrtMemcpy(rc.cos.get(), cos_h.size() * 2, cos_h.data(), cos_h.size() * 2, ACL_MEMCPY_HOST_TO_DEVICE));
63
+ ACL_CHECK(aclrtMemcpy(rc.sin.get(), sin_h.size() * 2, sin_h.data(), sin_h.size() * 2, ACL_MEMCPY_HOST_TO_DEVICE));
64
+ rc.max_seq = max_seq; rc.head_dim = head_dim; rc.theta = theta;
65
+ return true;
66
+ }
67
+
68
+ // Attention forward for a single layer.
69
+ //
70
+ // x_in [S, D] (hidden state, pre input_layernorm)
71
+ // x_out [S, D] (attention output — NOT residual-added)
72
+ //
73
+ // K cache / V cache are contiguous [MAX_LEN, KV_DIM] BF16 buffers. This call writes new
74
+ // positions at [past_len, past_len+S) and then runs FIAS over [0, past_len+S).
75
+ //
76
+ // Scratch requirements:
77
+ // q_scratch: S * Q_DIM * 2 bytes
78
+ // k_scratch: S * KV_DIM * 2 bytes
79
+ // v_scratch: S * KV_DIM * 2 bytes
80
+ // xn_scratch: S * D * 2 bytes
81
+ // rstd_scratch: S * 4 bytes (RmsNorm rstd output)
82
+ // rope_scratch: S * Hq * Dh * 2 bytes
83
+ //
84
+ // mask: [1, 1, 2048, 2048] bool for prefill; ignored (pass nullptr) for decode.
85
+ inline void attention_forward(
86
+ aclrtStream stream,
87
+ const ModelConfig& cfg,
88
+ LayerAttnWeights& w,
89
+ void* x_in, // [S, D] BF16
90
+ int64_t S,
91
+ int64_t past_len, // prior KV positions
92
+ void* k_cache, void* v_cache, int64_t max_len,
93
+ aclTensor* mask_tensor, // may be nullptr for decode
94
+ void* q_scratch, void* k_scratch, void* v_scratch,
95
+ void* xn_scratch, void* rstd_scratch, void* rope_scratch,
96
+ void* attn_out_scratch, // S * Q_DIM * 2 bytes (FIAS output before o_proj)
97
+ void* x_out, // [S, D] BF16
98
+ HcclCtx* hccl_ctx = nullptr, // if tp_size > 1, AllReduce x_out after o_proj
99
+ const RopeCache* rope_cache = nullptr, // if provided, use cached cos/sin table; avoids per-call H2D
100
+ int64_t sparse_mode = -1 // -1=auto (3 for prefill, 0 for decode); explicit 0/3 overrides
101
+ ) {
102
+ const int64_t D = cfg.hidden_size;
103
+ const int64_t Hq = cfg.n_heads_per_rank;
104
+ const int64_t Hkv = cfg.n_kv_heads_per_rank;
105
+ const int64_t Dh = cfg.head_dim;
106
+ const int64_t Q_DIM = Hq * Dh;
107
+ const int64_t KV_DIM = Hkv * Dh;
108
+ const double scale = 1.0 / std::sqrt((double)Dh);
109
+ const double eps = cfg.rms_norm_eps;
110
+ const float theta = cfg.rope_theta;
111
+
112
+ // 1. Input layernorm: xn = rmsnorm(x_in, input_layernorm_weight)
113
+ auto t_x = make_contig_tensor(x_in, ACL_BF16, {S, D});
114
+ auto t_xn = make_contig_tensor(xn_scratch, ACL_BF16, {S, D});
115
+ auto t_lnw = make_contig_tensor(w.input_layernorm.get(), ACL_BF16, {D});
116
+ auto t_rstd = make_contig_tensor(rstd_scratch, ACL_FLOAT, {S});
117
+ rms_norm(stream, t_x.get(), t_lnw.get(), eps, t_xn.get(), t_rstd.get());
118
+
119
+ // 2. Q/K/V projection
120
+ auto t_q = make_contig_tensor(q_scratch, ACL_BF16, {S, Q_DIM});
121
+ auto t_k = make_contig_tensor(k_scratch, ACL_BF16, {S, KV_DIM});
122
+ auto t_v = make_contig_tensor(v_scratch, ACL_BF16, {S, KV_DIM});
123
+ linear_hf(stream, t_xn.get(), w.q_proj.get(), ACL_BF16, Q_DIM, D, t_q.get());
124
+ linear_hf(stream, t_xn.get(), w.k_proj.get(), ACL_BF16, KV_DIM, D, t_k.get());
125
+ linear_hf(stream, t_xn.get(), w.v_proj.get(), ACL_BF16, KV_DIM, D, t_v.get());
126
+
127
+ // 3. Per-head q_norm, k_norm
128
+ auto t_q_4d = make_contig_tensor(q_scratch, ACL_BF16, {1, S, Hq, Dh});
129
+ auto t_k_4d = make_contig_tensor(k_scratch, ACL_BF16, {1, S, Hkv, Dh});
130
+ auto t_qn_w = make_contig_tensor(w.q_norm.get(), ACL_BF16, {Dh});
131
+ auto t_kn_w = make_contig_tensor(w.k_norm.get(), ACL_BF16, {Dh});
132
+ // reuse rstd_scratch split or allocate? reuse xn_scratch's first S*Hq*4 bytes.
133
+ // Simpler: require rstd_scratch to have max(S, S*max(Hq,Hkv)) * 4 bytes.
134
+ // For single-rank attention tests we pass enough.
135
+ auto t_rstd_q = make_contig_tensor(rstd_scratch, ACL_FLOAT, {1, S, Hq});
136
+ auto t_rstd_k = make_contig_tensor(rstd_scratch, ACL_FLOAT, {1, S, Hkv});
137
+ rms_norm(stream, t_q_4d.get(), t_qn_w.get(), eps, t_q_4d.get(), t_rstd_q.get());
138
+ rms_norm(stream, t_k_4d.get(), t_kn_w.get(), eps, t_k_4d.get(), t_rstd_k.get());
139
+
140
+ // 4. RoPE: positions [past_len, past_len + S). Fused aclnnApplyRotaryPosEmbV2 is 1 op
141
+ // vs 8-op manual version — saves ~7 kernel launches/layer × 94 layers = 658/token.
142
+ if (rope_cache && rope_cache->cos.get() && past_len + S <= rope_cache->max_seq) {
143
+ void* cos_ptr = (char*)rope_cache->cos.get() + past_len * Dh * 2;
144
+ void* sin_ptr = (char*)rope_cache->sin.get() + past_len * Dh * 2;
145
+ apply_rope_fused(stream, q_scratch, 1, S, Hq, Dh, k_scratch, Hkv, cos_ptr, sin_ptr);
146
+ } else {
147
+ std::vector<uint16_t> cos_h, sin_h;
148
+ fill_cos_sin_hf(cos_h, sin_h, past_len, S, Dh, theta);
149
+ DeviceBuffer cos_dev(S * Dh * 2), sin_dev(S * Dh * 2);
150
+ ACL_CHECK(aclrtMemcpy(cos_dev.get(), S*Dh*2, cos_h.data(), S*Dh*2, ACL_MEMCPY_HOST_TO_DEVICE));
151
+ ACL_CHECK(aclrtMemcpy(sin_dev.get(), S*Dh*2, sin_h.data(), S*Dh*2, ACL_MEMCPY_HOST_TO_DEVICE));
152
+ apply_rope_manual(stream, q_scratch, 1, S, Hq, Dh, k_scratch, Hkv,
153
+ cos_dev.get(), sin_dev.get(), rope_scratch);
154
+ // Local DeviceBuffers would be freed on return while async kernels still read them.
155
+ ACL_CHECK(aclrtSynchronizeStream(stream));
156
+ }
157
+
158
+ // 5. Append K, V to cache at [past_len, past_len + S)
159
+ ACL_CHECK(aclrtMemcpyAsync((char*)k_cache + past_len * KV_DIM * 2, S * KV_DIM * 2,
160
+ k_scratch, S * KV_DIM * 2, ACL_MEMCPY_DEVICE_TO_DEVICE, stream));
161
+ ACL_CHECK(aclrtMemcpyAsync((char*)v_cache + past_len * KV_DIM * 2, S * KV_DIM * 2,
162
+ v_scratch, S * KV_DIM * 2, ACL_MEMCPY_DEVICE_TO_DEVICE, stream));
163
+
164
+ // 6. FIAS: q [1, S, Q_DIM], k/v [1, kv_len, KV_DIM] from cache
165
+ int64_t kv_len = past_len + S;
166
+ auto t_q_bsh = make_contig_tensor(q_scratch, ACL_BF16, {1, S, Q_DIM});
167
+ auto t_k_bsh = make_contig_tensor(k_cache, ACL_BF16, {1, kv_len, KV_DIM});
168
+ auto t_v_bsh = make_contig_tensor(v_cache, ACL_BF16, {1, kv_len, KV_DIM});
169
+ // FIAS writes to a separate buffer (attn_out_scratch) — aliasing q→out is unsafe.
170
+ auto t_attn_out_bsh = make_contig_tensor(attn_out_scratch, ACL_BF16, {1, S, Q_DIM});
171
+ // sparse_mode selection:
172
+ // 3 = left-top causal (prefill, q.S == kv.S with 2048 mask)
173
+ // 0 = user mask (decode with cache, batch verify)
174
+ // -1 (sentinel) = auto: 3 if mask given & past_len==0 & S>1 (prefill), else 0
175
+ int64_t sparse = sparse_mode;
176
+ if (sparse < 0) {
177
+ sparse = (mask_tensor != nullptr && past_len == 0 && S > 1) ? 3 : 0;
178
+ }
179
+ fused_infer_attention_score(
180
+ stream, t_q_bsh.get(), t_k_bsh.get(), t_v_bsh.get(),
181
+ mask_tensor, {S}, {kv_len},
182
+ Hq, Hkv, scale, sparse, t_attn_out_bsh.get());
183
+
184
+ // 7. O projection: y = attn_out @ o_proj.T → [S, D]
185
+ auto t_attn_2d = make_contig_tensor(attn_out_scratch, ACL_BF16, {S, Q_DIM});
186
+ auto t_out = make_contig_tensor(x_out, ACL_BF16, {S, D});
187
+ linear_hf(stream, t_attn_2d.get(), w.o_proj.get(), ACL_BF16, D, Q_DIM, t_out.get());
188
+
189
+ // 8. TP AllReduce on x_out (row-parallel o_proj → SUM across ranks)
190
+ if (hccl_ctx && hccl_ctx->tp_size > 1) {
191
+ hccl_allreduce_bf16(*hccl_ctx, x_out, S * D, stream);
192
+ }
193
+ }
194
+
195
+ // MoE forward for a single layer. Residual NOT applied here.
196
+ //
197
+ // x_in [S, D] (hidden state, pre post_attention_layernorm)
198
+ // x_out [S, D] (MoE output)
199
+ //
200
+ // Scratch:
201
+ // xn_scratch: S * D * 2
202
+ // rstd_scratch: S * 4
203
+ // logits_scratch: S * E * 2
204
+ // topk_w_scratch: S * K * 2
205
+ // topk_idx_scratch: S * K * 4
206
+ // row_idx_scratch: S * K * 4 (gating output unused)
207
+ // expanded_x_scratch: TOTAL * D * 2
208
+ // expanded_ri_scratch:TOTAL * 4
209
+ // tpe_scratch: E * 8
210
+ // fwd_dev: TOTAL * 8
211
+ // gate_out_scratch: TOTAL * I * 2
212
+ // up_out_scratch: TOTAL * I * 2
213
+ // down_out_scratch: TOTAL * D * 2
214
+ // packed_scratch: TOTAL * D * 2
215
+ // weighted_scratch: S * K * D * 2
216
+ //
217
+ // where TOTAL = S * K, I = cfg.i_per_rank, E = cfg.num_experts, K = cfg.num_experts_per_tok.
218
+ //
219
+ // IMPORTANT: post_attention_layernorm weight in `attn_w` (not in LayerMoEWeights).
220
+ inline void moe_forward(
221
+ aclrtStream stream,
222
+ const ModelConfig& cfg,
223
+ LayerAttnWeights& attn_w, // for post_attention_layernorm
224
+ LayerMoEWeights& w,
225
+ void* x_in, int64_t S,
226
+ void* xn_scratch, void* rstd_scratch,
227
+ void* logits_scratch,
228
+ void* topk_w_scratch, void* topk_idx_scratch, void* row_idx_scratch,
229
+ void* expanded_x_scratch, void* expanded_ri_scratch, void* tpe_scratch,
230
+ void* fwd_scratch,
231
+ void* gate_out_scratch, void* up_out_scratch, void* down_out_scratch,
232
+ void* packed_scratch, void* weighted_scratch,
233
+ void* x_out,
234
+ HcclCtx* hccl_ctx = nullptr, // if tp_size > 1, AllReduce after reduce_sum
235
+ void* norm_sum_scratch = nullptr // S * 2 bytes — persistent buffer for topk_w normalize
236
+ ) {
237
+ const int64_t D = cfg.hidden_size;
238
+ const int64_t I = cfg.i_per_rank;
239
+ const int64_t E = cfg.num_experts;
240
+ const int64_t K = cfg.num_experts_per_tok;
241
+ const double eps = cfg.rms_norm_eps;
242
+ const int64_t TOTAL = S * K;
243
+
244
+ // 1. post_attention_layernorm
245
+ auto t_x = make_contig_tensor(x_in, ACL_BF16, {S, D});
246
+ auto t_xn = make_contig_tensor(xn_scratch, ACL_BF16, {S, D});
247
+ auto t_lnw = make_contig_tensor(attn_w.post_attention_layernorm.get(), ACL_BF16, {D});
248
+ auto t_rstd = make_contig_tensor(rstd_scratch, ACL_FLOAT, {S});
249
+ rms_norm(stream, t_x.get(), t_lnw.get(), eps, t_xn.get(), t_rstd.get());
250
+
251
+ // 2. Router linear: logits = xn @ router.T → [S, E]
252
+ auto t_logits = make_contig_tensor(logits_scratch, ACL_BF16, {S, E});
253
+ linear_hf(stream, t_xn.get(), w.router.get(), ACL_BF16, E, D, t_logits.get());
254
+
255
+ // 3. TopK softmax
256
+ auto t_topk_w = make_contig_tensor(topk_w_scratch, ACL_BF16, {S, K});
257
+ auto t_topk_idx = make_contig_tensor(topk_idx_scratch, ACL_INT32, {S, K});
258
+ auto t_row_idx = make_contig_tensor(row_idx_scratch, ACL_INT32, {S, K});
259
+ moe_gating_topk_softmax(stream, t_logits.get(), K, t_topk_w.get(), t_topk_idx.get(), t_row_idx.get());
260
+
261
+ // 4. Device-side normalize topk weights (Qwen3 norm_topk_prob=true).
262
+ // sum = reduce_sum(topk_w, dim=-1, keepdim=true) # [S, 1] F32 in rstd_scratch
263
+ // sum += 1e-20
264
+ // sum_bf16 = cast(sum, BF16) # [S, 1] in norm_sum_scratch (caller-owned)
265
+ // topk_w /= sum_bf16 # broadcast divide
266
+ // No per-layer syncs — all scratch buffers persist across layers.
267
+ if (norm_sum_scratch) {
268
+ auto t_sum = make_contig_tensor(rstd_scratch, ACL_FLOAT, {S, 1});
269
+ auto t_sum_bf16 = make_contig_tensor(norm_sum_scratch, ACL_BF16, {S, 1});
270
+ reduce_sum(stream, t_topk_w.get(), {-1}, /*keep_dims=*/true, ACL_FLOAT, t_sum.get());
271
+ inplace_adds(stream, t_sum.get(), 1e-20);
272
+ cast(stream, t_sum.get(), ACL_BF16, t_sum_bf16.get());
273
+ div_tensor(stream, t_topk_w.get(), t_sum_bf16.get(), t_topk_w.get());
274
+ } else {
275
+ // Fallback: host-side normalize (for callers that didn't provide scratch).
276
+ ACL_CHECK(aclrtSynchronizeStream(stream));
277
+ std::vector<uint16_t> h_tw(S * K);
278
+ ACL_CHECK(aclrtMemcpy(h_tw.data(), S*K*2, topk_w_scratch, S*K*2, ACL_MEMCPY_DEVICE_TO_HOST));
279
+ for (int s = 0; s < S; s++) {
280
+ float sum = 0;
281
+ for (int k = 0; k < K; k++) {
282
+ uint32_t u = (uint32_t)h_tw[s*K + k] << 16;
283
+ float v; std::memcpy(&v, &u, 4);
284
+ sum += v;
285
+ }
286
+ sum += 1e-20f;
287
+ for (int k = 0; k < K; k++) {
288
+ uint32_t u = (uint32_t)h_tw[s*K + k] << 16;
289
+ float v; std::memcpy(&v, &u, 4);
290
+ v /= sum;
291
+ std::memcpy(&u, &v, 4);
292
+ h_tw[s*K + k] = (uint16_t)((u + 0x7FFF + ((u >> 16) & 1)) >> 16);
293
+ }
294
+ }
295
+ ACL_CHECK(aclrtMemcpy(topk_w_scratch, S*K*2, h_tw.data(), S*K*2, ACL_MEMCPY_HOST_TO_DEVICE));
296
+ }
297
+
298
+ // 5. MoE init routing
299
+ auto t_ex_x = make_contig_tensor(expanded_x_scratch, ACL_BF16, {TOTAL, D});
300
+ auto t_ex_ri = make_contig_tensor(expanded_ri_scratch, ACL_INT32, {TOTAL});
301
+ auto t_tpe = make_contig_tensor(tpe_scratch, ACL_INT64, {E});
302
+ moe_init_routing_v3(stream, t_xn.get(), t_topk_idx.get(),
303
+ E, TOTAL, t_ex_x.get(), t_ex_ri.get(), t_tpe.get());
304
+
305
+ // 6. GMM gate + up
306
+ auto t_gate_out = make_contig_tensor(gate_out_scratch, ACL_BF16, {TOTAL, I});
307
+ auto t_up_out = make_contig_tensor(up_out_scratch, ACL_BF16, {TOTAL, I});
308
+ auto t_w_gate = make_contig_tensor(w.gate_exps.get(), ACL_BF16, {E, D, I});
309
+ auto t_w_up = make_contig_tensor(w.up_exps.get(), ACL_BF16, {E, D, I});
310
+ grouped_matmul_v4(stream, t_ex_x.get(), t_w_gate.get(), t_tpe.get(), t_gate_out.get(), 1);
311
+ grouped_matmul_v4(stream, t_ex_x.get(), t_w_up.get(), t_tpe.get(), t_up_out.get(), 1);
312
+
313
+ // 7. SwiGLU: gate_out = silu(gate_out) * up_out
314
+ silu(stream, t_gate_out.get(), t_gate_out.get());
315
+ mul(stream, t_gate_out.get(), t_up_out.get(), t_gate_out.get());
316
+
317
+ // 8. GMM down
318
+ auto t_down_out = make_contig_tensor(down_out_scratch, ACL_BF16, {TOTAL, D});
319
+ auto t_w_down = make_contig_tensor(w.down_exps.get(), ACL_BF16, {E, I, D});
320
+ grouped_matmul_v4(stream, t_gate_out.get(), t_w_down.get(), t_tpe.get(), t_down_out.get(), 1);
321
+
322
+ // 9. Device-side finalize: build forward perm via two consecutive argsorts on topk_idx.
323
+ // No host sync — safe for graph capture.
324
+ // inv_fwd = argsort(topk_idx.flat) // each (n,k) → sorted position (primary key: expert_id)
325
+ // fwd = argsort(inv_fwd) // inverse perm — what IndexSelect needs
326
+ // Stability: aclnnArgsort preserves input order for equal keys; flat index = n*K + k orders
327
+ // ties by n-then-k, matching our previous manual sort convention.
328
+ //
329
+ // Scratch for inv_fwd: reuse first TOTAL*8 bytes of weighted_scratch (gets overwritten
330
+ // by the subsequent mul op, so aliasing is safe).
331
+ {
332
+ auto t_topk_idx_flat = make_contig_tensor(topk_idx_scratch, ACL_INT32, {TOTAL});
333
+ auto t_inv_fwd = make_contig_tensor(weighted_scratch, ACL_INT64, {TOTAL});
334
+ auto t_fwd_64 = make_contig_tensor(fwd_scratch, ACL_INT64, {TOTAL});
335
+ argsort(stream, t_topk_idx_flat.get(), /*dim=*/0, /*descending=*/false, t_inv_fwd.get());
336
+ argsort(stream, t_inv_fwd.get(), /*dim=*/0, /*descending=*/false, t_fwd_64.get());
337
+ }
338
+ auto t_fwd = make_contig_tensor(fwd_scratch, ACL_INT64, {TOTAL});
339
+ auto t_packed = make_contig_tensor(packed_scratch, ACL_BF16, {TOTAL, D});
340
+ index_select(stream, t_down_out.get(), 0, t_fwd.get(), t_packed.get());
341
+
342
+ auto t_packed_3d = make_contig_tensor(packed_scratch, ACL_BF16, {S, K, D});
343
+ auto t_topk_w_3d = make_contig_tensor(topk_w_scratch, ACL_BF16, {S, K, 1});
344
+ auto t_weighted = make_contig_tensor(weighted_scratch, ACL_BF16, {S, K, D});
345
+ mul(stream, t_packed_3d.get(), t_topk_w_3d.get(), t_weighted.get());
346
+
347
+ auto t_out = make_contig_tensor(x_out, ACL_BF16, {S, D});
348
+ reduce_sum(stream, t_weighted.get(), {1}, false, ACL_BF16, t_out.get());
349
+
350
+ // TP AllReduce on MoE output (column-parallel experts → SUM partial intermediate outputs)
351
+ if (hccl_ctx && hccl_ctx->tp_size > 1) {
352
+ hccl_allreduce_bf16(*hccl_ctx, x_out, S * D, stream);
353
+ }
354
+ }
include/hccl_comm.h ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // hccl_comm.h — minimal HCCL wrapper for TP=N AllReduce.
2
+ //
3
+ // Multi-process mode (each rank is a separate process, device 0 each):
4
+ // - Rank 0 calls HcclGetRootInfo, writes to /tmp/hccl_root_info.bin
5
+ // - Rank 1..N-1 wait for that file, read it
6
+ // - All ranks call HcclCommInitRootInfo → shared HcclComm
7
+ // - allreduce() does in-place HcclAllReduce with SUM op
8
+ //
9
+ // Launcher sets HCCL_WHITELIST_DISABLE=1, ASCEND_RT_VISIBLE_DEVICES=<rank>, etc.
10
+ #pragma once
11
+ #include <hccl/hccl.h>
12
+ #include <hccl/hccl_types.h>
13
+ #include <acl/acl.h>
14
+
15
+ #include <chrono>
16
+ #include <cstdio>
17
+ #include <cstring>
18
+ #include <string>
19
+ #include <thread>
20
+
21
+ #define HCCL_ROOT_INFO_PATH "/tmp/hccl_root_info.bin"
22
+
23
+ struct HcclCtx {
24
+ HcclComm comm = nullptr;
25
+ int tp_size = 1;
26
+ int tp_rank = 0;
27
+ bool initialized = false;
28
+ };
29
+
30
+ inline bool hccl_init(HcclCtx& ctx, int tp_size, int tp_rank) {
31
+ if (tp_size <= 1) { ctx.tp_size = 1; ctx.tp_rank = 0; ctx.initialized = true; return true; }
32
+ ctx.tp_size = tp_size;
33
+ ctx.tp_rank = tp_rank;
34
+
35
+ HcclRootInfo rootInfo;
36
+ std::memset(&rootInfo, 0, sizeof(rootInfo));
37
+
38
+ if (tp_rank == 0) {
39
+ if (HcclGetRootInfo(&rootInfo) != HCCL_SUCCESS) {
40
+ fprintf(stderr, "[HCCL] HcclGetRootInfo failed\n"); return false;
41
+ }
42
+ FILE* f = fopen(HCCL_ROOT_INFO_PATH, "wb");
43
+ if (!f) { fprintf(stderr, "[HCCL] cannot write %s\n", HCCL_ROOT_INFO_PATH); return false; }
44
+ fwrite(&rootInfo, sizeof(rootInfo), 1, f);
45
+ fclose(f);
46
+ } else {
47
+ bool found = false;
48
+ for (int r = 0; r < 600; r++) { // 60s timeout
49
+ FILE* f = fopen(HCCL_ROOT_INFO_PATH, "rb");
50
+ if (f) {
51
+ size_t rd = fread(&rootInfo, 1, sizeof(rootInfo), f);
52
+ fclose(f);
53
+ if (rd == sizeof(rootInfo)) { found = true; break; }
54
+ }
55
+ std::this_thread::sleep_for(std::chrono::milliseconds(100));
56
+ }
57
+ if (!found) { fprintf(stderr, "[HCCL] rank %d timeout waiting for root info\n", tp_rank); return false; }
58
+ }
59
+
60
+ HcclResult r = HcclCommInitRootInfo((uint32_t)tp_size, &rootInfo, (uint32_t)tp_rank, &ctx.comm);
61
+ if (r != HCCL_SUCCESS) {
62
+ fprintf(stderr, "[HCCL] HcclCommInitRootInfo failed: %d (rank=%d)\n", (int)r, tp_rank);
63
+ return false;
64
+ }
65
+ ctx.initialized = true;
66
+ fprintf(stderr, "[HCCL] rank %d/%d comm OK\n", tp_rank, tp_size);
67
+ return true;
68
+ }
69
+
70
+ // In-place AllReduce SUM on BF16 tensor. dtype = HCCL_DATA_TYPE_BFP16.
71
+ inline bool hccl_allreduce_bf16(const HcclCtx& ctx, void* data, int64_t count, aclrtStream stream) {
72
+ if (!ctx.initialized) return false;
73
+ if (ctx.tp_size <= 1) return true; // no-op
74
+
75
+ HcclResult r = HcclAllReduce(data, data, (uint64_t)count,
76
+ HCCL_DATA_TYPE_BFP16, HCCL_REDUCE_SUM,
77
+ ctx.comm, stream);
78
+ if (r != HCCL_SUCCESS) {
79
+ fprintf(stderr, "[HCCL] AllReduce failed: %d\n", (int)r);
80
+ return false;
81
+ }
82
+ return true;
83
+ }
84
+
85
+ // Broadcast buffer from root (rank 0) to all ranks. Used to share prompt tokens across ranks.
86
+ // `data_dev` must be device memory. dtype generic (e.g., HCCL_DATA_TYPE_INT32).
87
+ inline bool hccl_broadcast(const HcclCtx& ctx, void* data_dev, int64_t count,
88
+ HcclDataType dtype, uint32_t root, aclrtStream stream) {
89
+ if (!ctx.initialized) return false;
90
+ if (ctx.tp_size <= 1) return true;
91
+
92
+ HcclResult r = HcclBroadcast(data_dev, (uint64_t)count, dtype, root, ctx.comm, stream);
93
+ if (r != HCCL_SUCCESS) {
94
+ fprintf(stderr, "[HCCL] Broadcast failed: %d\n", (int)r);
95
+ return false;
96
+ }
97
+ return true;
98
+ }
99
+
100
+ inline void hccl_shutdown(HcclCtx& ctx) {
101
+ if (ctx.comm) {
102
+ HcclCommDestroy(ctx.comm);
103
+ ctx.comm = nullptr;
104
+ }
105
+ ctx.initialized = false;
106
+ }
include/model_config.h ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // model_config.h — Qwen3 hparams loaded from HF config.json, plus TP-derived per-rank sizes.
2
+ #pragma once
3
+ #include <cstdint>
4
+ #include <string>
5
+
6
+ struct ModelConfig {
7
+ // ---- Raw hparams from config.json ----
8
+ int64_t vocab_size = 0;
9
+ int64_t hidden_size = 0; // D
10
+ int64_t intermediate_size = 0; // dense FFN (not used for MoE layers; kept for completeness)
11
+ int64_t moe_intermediate_size = 0; // I per expert
12
+ int64_t num_hidden_layers = 0; // = 94 for Qwen3-235B
13
+ int64_t num_attention_heads = 0; // = 64
14
+ int64_t num_key_value_heads = 0; // = 4 (GQA)
15
+ int64_t head_dim = 0; // = 128
16
+ int64_t num_experts = 0; // = 128
17
+ int64_t num_experts_per_tok = 0; // top_k = 8
18
+ int64_t max_position_embeddings = 0;
19
+ float rope_theta = 0.0f;
20
+ float rms_norm_eps = 1e-6f;
21
+ bool norm_topk_prob = true;
22
+ bool tie_word_embeddings = false;
23
+ int64_t bos_token_id = 0;
24
+ int64_t eos_token_id = 0;
25
+
26
+ // ---- TP configuration ----
27
+ int tp_size = 1;
28
+ int tp_rank = 0;
29
+
30
+ // ---- Derived per-rank sizes ----
31
+ // Attention Q: split along num_heads (head-parallel)
32
+ // n_heads_per_rank = num_attention_heads / tp_size
33
+ // q_dim_per_rank = n_heads_per_rank * head_dim
34
+ int64_t n_heads_per_rank = 0;
35
+ int64_t q_dim_per_rank = 0;
36
+
37
+ // Attention KV: GQA with num_kv_heads < tp_size needs special handling.
38
+ // For Qwen3-235B: num_kv_heads = 4, tp_size = 16 → each KV head is replicated 4× across ranks.
39
+ // Simple scheme: each rank computes ALL kv heads (small, 4 × 128 = 512 features)
40
+ // then slices attention output for its own q heads.
41
+ // Alternative: split KV heads if tp_size <= num_kv_heads.
42
+ int64_t n_kv_heads_per_rank = 0;
43
+ int64_t kv_dim_per_rank = 0;
44
+
45
+ // MoE: intermediate dim split. Each rank holds 1/tp_size of experts' intermediate_size.
46
+ // i_per_rank = moe_intermediate_size / tp_size
47
+ int64_t i_per_rank = 0;
48
+
49
+ bool load_from_json(const std::string& path);
50
+ void compute_derived(int tp_size, int tp_rank);
51
+ std::string describe() const;
52
+ };
include/rope.h ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // rope.h — Manual HF-style RoPE using basic aclnn ops.
2
+ //
3
+ // Formula: q_out = q * cos + rotate_half(q) * sin
4
+ // where rotate_half(q) = concat(-q[..., d/2:], q[..., :d/2], dim=-1)
5
+ //
6
+ // Tensor layout: q/k [B, S, N, Dh] BF16, cos/sin [1, S, Dh] BF16
7
+ // (cos/sin are broadcast across B and N dims)
8
+ //
9
+ #pragma once
10
+ #include "acl_common.h"
11
+ #include "aclnn_ops.h"
12
+ #include <aclnnop/aclnn_apply_rotary_pos_emb_v2.h>
13
+
14
+ // Fused RoPE via aclnnApplyRotaryPosEmbV2 — replaces the 8-op manual version with a single
15
+ // op, saving ~7 launches per layer × 94 layers = ~658 launches/token. Validated on 910 initial:
16
+ // layout=1 + rotaryMode="half" matches HF rotate_half semantics (rel=1.24e-3 vs manual).
17
+ //
18
+ // q_data: [B, S, Nq, Dh] BF16 (modified in place)
19
+ // k_data: [B, S, Nk, Dh] BF16 (modified in place)
20
+ // cos_data / sin_data: [1, S, 1, Dh] BF16 (single contiguous buffer slice from RopeCache)
21
+ inline void apply_rope_fused(aclrtStream stream,
22
+ void* q_data, int64_t B, int64_t S, int64_t Nq, int64_t Dh,
23
+ void* k_data, int64_t Nk,
24
+ void* cos_data, void* sin_data) {
25
+ const aclDataType dt = ACL_BF16;
26
+ auto t_q = make_contig_tensor(q_data, dt, {B, S, Nq, Dh});
27
+ auto t_k = make_contig_tensor(k_data, dt, {B, S, Nk, Dh});
28
+ auto t_cos = make_contig_tensor(cos_data, dt, {1, S, 1, Dh});
29
+ auto t_sin = make_contig_tensor(sin_data, dt, {1, S, 1, Dh});
30
+ uint64_t ws = 0; aclOpExecutor* exec = nullptr;
31
+ char mode[] = "half";
32
+ ACLNN_CHECK(aclnnApplyRotaryPosEmbV2GetWorkspaceSize(
33
+ t_q.get(), t_k.get(), t_cos.get(), t_sin.get(),
34
+ /*layout=*/1, mode, &ws, &exec));
35
+ void* wp = (ws > 0) ? _lca_pool().alloc(ws) : nullptr;
36
+ ACLNN_CHECK(aclnnApplyRotaryPosEmbV2(wp, ws, exec, stream));
37
+ }
38
+
39
+ // Apply RoPE in-place to q and k.
40
+ // q_data: pointer to [B, S, Nq, Dh] BF16 (modified in place)
41
+ // k_data: pointer to [B, S, Nk, Dh] BF16 (modified in place)
42
+ // cos_data, sin_data: [1, S, Dh] BF16
43
+ // scratch_data: pointer to contiguous [B, S, max(Nq,Nk), Dh] BF16 scratch buffer for rotate_half
44
+ inline void apply_rope_manual(aclrtStream stream,
45
+ void* q_data, int64_t B, int64_t S, int64_t Nq, int64_t Dh,
46
+ void* k_data, int64_t Nk,
47
+ void* cos_data, void* sin_data,
48
+ void* scratch_data) {
49
+ const aclDataType dt = ACL_BF16;
50
+ const size_t elem = 2;
51
+ const int64_t halfDh = Dh / 2;
52
+
53
+ auto process = [&](void* x_data, int64_t N) {
54
+ // Strides in elements (row-major [B, S, N, Dh]):
55
+ // stride = [S*N*Dh, N*Dh, Dh, 1]
56
+ const std::vector<int64_t> full_shape = {B, S, N, Dh};
57
+ const std::vector<int64_t> full_stride = {S*N*Dh, N*Dh, Dh, 1};
58
+ const std::vector<int64_t> half_shape = {B, S, N, halfDh};
59
+ const std::vector<int64_t> half_stride = full_stride; // same leading 3 strides
60
+
61
+ // View x as full
62
+ auto t_x = make_acl_tensor(x_data, dt, full_shape, full_stride);
63
+
64
+ // View of x left half and right half (shifted pointers, same layout, last dim half)
65
+ auto t_x_left = make_acl_tensor(x_data, dt, half_shape, half_stride);
66
+ auto t_x_right = make_acl_tensor((char*)x_data + halfDh*elem, dt, half_shape, half_stride);
67
+
68
+ // rotate_half buffer view (contiguous [B, S, N, Dh])
69
+ const std::vector<int64_t> rh_stride = {S*N*Dh, N*Dh, Dh, 1};
70
+ auto t_rh = make_acl_tensor(scratch_data, dt, full_shape, rh_stride);
71
+ auto t_rh_left = make_acl_tensor(scratch_data, dt, half_shape, rh_stride);
72
+ auto t_rh_right = make_acl_tensor((char*)scratch_data + halfDh*elem, dt, half_shape, rh_stride);
73
+
74
+ // rh[..., :Dh/2] = -x[..., Dh/2:]
75
+ neg(stream, t_x_right.get(), t_rh_left.get());
76
+ // rh[..., Dh/2:] = x[..., :Dh/2]
77
+ inplace_copy(stream, t_rh_right.get(), t_x_left.get());
78
+
79
+ // cos/sin views broadcastable to [B, S, N, Dh]
80
+ // Original storage: [1, S, Dh]. For broadcast, use shape [1, S, 1, Dh] with strides [0, Dh, 0, 1].
81
+ auto t_cos = make_acl_tensor(cos_data, dt, {1, S, 1, Dh}, {0, Dh, 0, 1});
82
+ auto t_sin = make_acl_tensor(sin_data, dt, {1, S, 1, Dh}, {0, Dh, 0, 1});
83
+
84
+ // q_rot = q * cos + rh * sin (use addcmul: q *= cos, then q += rh * sin)
85
+ // Compute tmp = q * cos (fresh buffer needed; use scratch_data is occupied by rh)
86
+ // Better: multiply x in place: x *= cos, then x += rh * sin
87
+ // aclnnMul with x as both in and out is inplace.
88
+ mul(stream, t_x.get(), t_cos.get(), t_x.get()); // x = x * cos
89
+ addcmul(stream, t_x.get(), t_rh.get(), t_sin.get(), 1); // x += 1 * (rh * sin)
90
+ };
91
+
92
+ process(q_data, Nq);
93
+ process(k_data, Nk);
94
+ }
include/runner.h ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // runner.h — multi-layer transformer Runner for Qwen3-235B-A22B.
2
+ //
3
+ // Owns: shared weights, per-layer attention + MoE weights, KV cache, scratch buffers.
4
+ // Provides: prefill(tokens) and decode(new_token) methods returning logits [vocab] on device.
5
+ //
6
+ // Memory budget at TP=1 for testing a SUBSET of layers (num_layers_to_load <= 94). Full 94-layer
7
+ // inference requires TP=16 where per-rank MoE fits ~28GB.
8
+ #pragma once
9
+ #include "acl_common.h"
10
+ #include "acl_runtime.h"
11
+ #include "aclnn_ops.h"
12
+ #include "device_weights.h"
13
+ #include "engine.h"
14
+ #include "hccl_comm.h"
15
+ #include "model_config.h"
16
+ #include "safetensors_loader.h"
17
+
18
+ #include <vector>
19
+
20
+ class Runner {
21
+ public:
22
+ Runner() = default;
23
+ ~Runner() = default;
24
+ Runner(const Runner&) = delete;
25
+ Runner& operator=(const Runner&) = delete;
26
+
27
+ // Initialize runtime, open safetensors, load shared weights. tp_size/tp_rank configure
28
+ // MoE + attention sharding. num_layers is how many transformer blocks to load (1..94).
29
+ // max_seq is the maximum sequence length (for KV cache allocation).
30
+ bool init(const std::string& model_dir, int tp_size, int tp_rank,
31
+ int num_layers_to_load, int64_t max_seq, int device_id = 0);
32
+
33
+ // Prefill: ingest S>=1 tokens, produces logits [vocab] for the LAST position. Populates KV
34
+ // cache starting at position 0. `hidden_out` optionally returns the final hidden state [S, D].
35
+ bool prefill(const int32_t* tokens, int64_t S, DeviceBuffer& logits_out);
36
+
37
+ // Decode: take 1 new token, produce logits [vocab] from the new position.
38
+ bool decode(int32_t token, DeviceBuffer& logits_out);
39
+
40
+ // Batched decode: take S tokens as "candidate verify batch" at positions [past_len..past_len+S),
41
+ // produce logits [S, vocab]. Uses causal-with-past mask (token i sees past+tokens[0..i]).
42
+ // Foundation for speculative decoding / PLD.
43
+ // tokens: [S] int32
44
+ // S: 1 .. 16
45
+ // all_logits_out: will hold S * vocab_size * 2 bytes BF16, row-major [S, V]
46
+ // Updates past_len by +S on success.
47
+ bool decode_batch(const int32_t* tokens, int64_t S, DeviceBuffer& all_logits_out);
48
+
49
+ // Warmup: run N dummy decode() calls (resetting cache) to pre-compile aclnn executors,
50
+ // warm HCCL collective buffers, and stabilize NPU thermals. Improves first-N-token latency
51
+ // by ~1 s (especially noticeable on short generations or REPL cold start).
52
+ // Call after init(); safe to call multiple times. Does NOT affect past_len.
53
+ void warmup(int iterations = 3);
54
+
55
+ // Accessors
56
+ const ModelConfig& cfg() const { return cfg_; }
57
+ aclrtStream stream() { return rt_.stream(); }
58
+ int64_t past_len() const { return past_len_; }
59
+ void reset_cache() { past_len_ = 0; }
60
+ // Rewind past_len by n. Used by speculative decoding to discard rejected draft tokens'
61
+ // KV cache entries (they'll be overwritten by subsequent writes).
62
+ void rewind_cache(int64_t n) { if (n > 0 && n <= past_len_) past_len_ -= n; }
63
+ HcclCtx& hccl_ctx() { return hccl_ctx_; }
64
+
65
+ // Profiling: set via LCA_PROFILE=1 env in main_cli. If enabled, decode() accumulates
66
+ // per-phase wall-clock ms into the timer accumulators below.
67
+ bool profile_enabled = false;
68
+ double t_embed_ms = 0, t_layers_ms = 0, t_final_ms = 0;
69
+ int64_t profile_calls = 0;
70
+ void print_profile_summary() const;
71
+
72
+ private:
73
+ // One-layer forward: x_in [S, D] → x_out [S, D] via attention + residual + MoE + residual.
74
+ // Uses this layer's KV cache starting at past_len; caller updates past_len after each call.
75
+ // batch_decode_mode: true for S>1 at past_len>0 (spec decoding) — uses custom causal mask
76
+ // with past instead of the 2048×2048 prefill mask.
77
+ void layer_forward_(int layer_idx, int64_t S, void* x_in, void* x_out,
78
+ bool batch_decode_mode = false);
79
+
80
+ // Build causal-with-past mask in batch_mask_dev_ for decode_batch at current past_len.
81
+ // Shape [1, 1, S, past_len+S] bool, mask[i, j] = 1 iff j > past_len+i.
82
+ void build_batch_decode_mask_(int64_t S);
83
+
84
+ // Final: final_norm + lm_head on last position → logits [vocab].
85
+ void final_logits_(void* hidden_last /*[1, D]*/, DeviceBuffer& logits_out);
86
+
87
+ // Batched final: final_norm + lm_head on [S, D] → logits [S, V].
88
+ void final_logits_batch_(void* hidden /*[S, D]*/, int64_t S, DeviceBuffer& logits_out);
89
+
90
+ AclRuntime rt_;
91
+ SafetensorsLoader st_;
92
+ ModelConfig cfg_;
93
+ HcclCtx hccl_ctx_;
94
+ int num_layers_ = 0;
95
+ int64_t max_seq_ = 0;
96
+
97
+ SharedWeights shared_;
98
+ std::vector<LayerAttnWeights> attn_;
99
+ std::vector<LayerMoEWeights> moe_;
100
+
101
+ // Per-layer KV cache
102
+ std::vector<DeviceBuffer> k_cache_;
103
+ std::vector<DeviceBuffer> v_cache_;
104
+
105
+ // Scratch (reallocated per-call sized by current S)
106
+ DeviceBuffer q_sc_, k_sc_, v_sc_, xn_sc_, rstd_sc_, rope_sc_, attn_fias_sc_, attn_out_sc_;
107
+ DeviceBuffer moe_xn_, moe_rstd_, moe_logits_;
108
+ DeviceBuffer moe_topk_w_, moe_topk_idx_, moe_row_idx_;
109
+ DeviceBuffer moe_ex_x_, moe_ex_ri_, moe_tpe_;
110
+ DeviceBuffer moe_fwd_;
111
+ DeviceBuffer moe_gate_, moe_up_, moe_down_;
112
+ DeviceBuffer moe_packed_, moe_weighted_, moe_out_;
113
+ DeviceBuffer moe_norm_sum_; // BF16 [S, 1] for on-device topk_w normalize
114
+ DeviceBuffer x_buf_a_, x_buf_b_; // ping-pong for residual chain
115
+
116
+ // Causal mask for prefill (2048 x 2048 bool); decode uses nullptr
117
+ DeviceBuffer prefill_mask_dev_;
118
+
119
+ // Batch decode mask: S_MAX × KV_MAX bool, where mask[i, j] = 1 (masked out) if
120
+ // j > past_len + i. Built on-demand per-call (past_len changes).
121
+ DeviceBuffer batch_mask_dev_;
122
+
123
+ // Pre-computed RoPE cos/sin table (sized for max_seq_)
124
+ RopeCache rope_cache_;
125
+
126
+ int64_t past_len_ = 0;
127
+ int64_t cur_S_capacity_ = 0; // scratch sized for this many tokens
128
+ };
include/safetensors_loader.h ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // safetensors_loader.h — lazy multi-shard safetensors reader.
2
+ //
3
+ // Usage:
4
+ // SafetensorsLoader loader;
5
+ // loader.open("/path/to/model_dir"); // parses index.json + all shard headers
6
+ // auto meta = loader.get("model.layers.0.self_attn.q_proj.weight");
7
+ // const void* host_ptr = loader.data_ptr(meta); // mmap-backed, host memory
8
+ // // copy to device: aclrtMemcpy(d_ptr, n, host_ptr, n, ACL_MEMCPY_HOST_TO_DEVICE);
9
+ //
10
+ // Files are mmap'd on first access and unmapped at destruction.
11
+ //
12
+ #pragma once
13
+ #include <cstdint>
14
+ #include <map>
15
+ #include <string>
16
+ #include <unordered_map>
17
+ #include <vector>
18
+
19
+ struct TensorMeta {
20
+ std::string name;
21
+ std::string dtype; // "BF16", "F16", "F32", "I32", "I64"
22
+ std::vector<int64_t> shape;
23
+ int shard_id = -1; // index into SafetensorsLoader::shards_
24
+ size_t offset = 0; // byte offset within shard (after 8B header_len + JSON header)
25
+ size_t nbytes = 0;
26
+ };
27
+
28
+ struct ShardFile {
29
+ std::string path;
30
+ int fd = -1;
31
+ void* mmap_ptr = nullptr;
32
+ size_t mmap_size = 0;
33
+ size_t data_base = 0; // byte offset to first tensor data within file
34
+ };
35
+
36
+ class SafetensorsLoader {
37
+ public:
38
+ SafetensorsLoader();
39
+ ~SafetensorsLoader();
40
+
41
+ // Opens a HuggingFace model directory. Returns false on failure.
42
+ // Expects: <dir>/model.safetensors.index.json + model-XXXXX-of-YYYYY.safetensors
43
+ bool open(const std::string& model_dir);
44
+
45
+ // Get tensor metadata. Returns nullptr if name not found.
46
+ const TensorMeta* get(const std::string& name) const;
47
+
48
+ // Return host pointer to tensor's raw bytes (mmap-backed). Null if not found or mmap failed.
49
+ const void* data_ptr(const TensorMeta& m);
50
+ const void* data_ptr(const std::string& name);
51
+
52
+ // Enumerate all tensor names (stable order = lexicographic).
53
+ std::vector<std::string> list_tensor_names() const;
54
+
55
+ // Stats
56
+ size_t tensor_count() const { return tensors_.size(); }
57
+ size_t shard_count() const { return shards_.size(); }
58
+ size_t total_bytes() const;
59
+
60
+ private:
61
+ bool parse_shard_header_(int shard_id);
62
+ bool mmap_shard_(int shard_id);
63
+
64
+ std::string model_dir_;
65
+ std::vector<ShardFile> shards_;
66
+ std::map<std::string, TensorMeta> tensors_; // ordered for determinism
67
+ };
68
+
69
+ // ---- Helpers ----
70
+
71
+ // Convert safetensors dtype string to element byte size.
72
+ inline size_t sdtype_size(const std::string& s) {
73
+ if (s == "F32" || s == "I32") return 4;
74
+ if (s == "F16" || s == "BF16" || s == "I16") return 2;
75
+ if (s == "F64" || s == "I64") return 8;
76
+ if (s == "I8" || s == "U8" || s == "BOOL") return 1;
77
+ return 0;
78
+ }
include/tokenizer.h ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // tokenizer.h — minimal Qwen3 tokenizer.
2
+ //
3
+ // M2-phase1: decode() is native C++ (simple vocab lookup). encode() is a Python subprocess
4
+ // (one-time cost at prompt setup). Native BPE encode is a future item.
5
+ //
6
+ #pragma once
7
+ #include <string>
8
+ #include <vector>
9
+ #include <cstdint>
10
+
11
+ class Tokenizer {
12
+ public:
13
+ bool load(const std::string& vocab_bin_path);
14
+
15
+ // Decode a single token id to UTF-8 string.
16
+ std::string decode(int token_id) const;
17
+
18
+ // Decode list of token ids to concatenated UTF-8 string.
19
+ std::string decode(const std::vector<int>& token_ids) const;
20
+
21
+ // Encode prompt to token ids. Uses a Python subprocess since Qwen3 needs proper BPE.
22
+ // The subprocess call takes ~200ms but is only invoked once per prompt.
23
+ std::vector<int> encode_via_python(const std::string& model_dir,
24
+ const std::string& prompt,
25
+ bool apply_chat_template = false) const;
26
+
27
+ // Encode a multi-turn conversation by applying the model's chat template. Each pair is
28
+ // (role, content) — typical roles: "system", "user", "assistant". Uses Python subprocess.
29
+ std::vector<int> encode_conversation_via_python(
30
+ const std::string& model_dir,
31
+ const std::vector<std::pair<std::string, std::string>>& conversation,
32
+ bool add_generation_prompt = true) const;
33
+
34
+ size_t size() const { return id_to_bytes_.size(); }
35
+
36
+ private:
37
+ std::vector<std::string> id_to_bytes_; // id -> raw utf-8 bytes
38
+ };
include/workspace_pool.h ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // workspace_pool.h — reusable aclnn workspace buffer pool.
2
+ //
3
+ // Problem: every aclnn op does `aclrtMalloc(workspace)` + `aclrtFree`. For decode at 94 layers
4
+ // × ~30 ops = 2820 mallocs/frees per token, this is significant overhead.
5
+ //
6
+ // Solution: pool of DeviceBuffers, grow-only. Pool returns a pointer >= requested size.
7
+ // Most ops reuse the SAME buffer since they don't overlap on-stream (serial execution).
8
+ //
9
+ // Thread safety: not thread-safe. One pool per Runner (one thread).
10
+ #pragma once
11
+ #include "acl_common.h"
12
+ #include <algorithm>
13
+ #include <vector>
14
+
15
+ class WorkspacePool {
16
+ public:
17
+ WorkspacePool() = default;
18
+ ~WorkspacePool() = default;
19
+ WorkspacePool(const WorkspacePool&) = delete;
20
+ WorkspacePool& operator=(const WorkspacePool&) = delete;
21
+
22
+ // Return a device pointer of at least `bytes`. Reuses the current buffer
23
+ // if it's big enough; otherwise grows by allocating a new one and
24
+ // **retaining old buffers** (async kernels may still be reading them —
25
+ // freeing too early would corrupt in-flight workspaces).
26
+ //
27
+ // Periodically call `reset_after_sync()` when the stream is idle to
28
+ // reclaim all-but-largest buffers and reset grow count.
29
+ void* alloc(size_t bytes) {
30
+ if (bytes == 0) return nullptr;
31
+ if (current_size_ < bytes) {
32
+ // Keep old buffer alive (don't free!) — aclnn kernels may still use it.
33
+ old_bufs_.push_back(std::move(buf_));
34
+ buf_.alloc(bytes);
35
+ current_size_ = bytes;
36
+ grow_count_++;
37
+ }
38
+ return buf_.get();
39
+ }
40
+
41
+ size_t current_size() const { return current_size_; }
42
+ size_t grow_count() const { return grow_count_; }
43
+ size_t retained_count() const { return old_bufs_.size(); }
44
+
45
+ // Call only when the stream is guaranteed idle (e.g., after aclrtSynchronizeStream).
46
+ // Drops all retained older buffers, freeing device memory. Current active buffer kept.
47
+ void reset_after_sync() {
48
+ old_bufs_.clear();
49
+ }
50
+
51
+ void clear() {
52
+ old_bufs_.clear();
53
+ buf_ = DeviceBuffer();
54
+ current_size_ = 0;
55
+ grow_count_ = 0;
56
+ }
57
+
58
+ private:
59
+ DeviceBuffer buf_; // current active (largest so far)
60
+ std::vector<DeviceBuffer> old_bufs_; // older, smaller — still live until stream sync
61
+ size_t current_size_ = 0;
62
+ size_t grow_count_ = 0;
63
+ };
64
+
65
+ // Convenience: per-stream RAII guard that acts like a `DeviceBuffer` but draws from pool.
66
+ // Used in aclnn_ops.h wrappers as a drop-in replacement for the local DeviceBuffer.
67
+ class PoolBuffer {
68
+ public:
69
+ // Fallback mode: if pool is nullptr, allocate own buffer (current behavior).
70
+ // Pool mode: return pool's shared pointer.
71
+ PoolBuffer(WorkspacePool* pool, size_t bytes) {
72
+ if (pool) {
73
+ ptr_ = pool->alloc(bytes);
74
+ } else if (bytes > 0) {
75
+ local_.alloc(bytes);
76
+ ptr_ = local_.get();
77
+ }
78
+ }
79
+ void* get() { return ptr_; }
80
+
81
+ private:
82
+ DeviceBuffer local_; // only used when pool is null
83
+ void* ptr_ = nullptr;
84
+ };
scripts/bench_hccl.sh ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+ # bench_hccl.sh — HCCL 参数矩阵 benchmark for TG
3
+ #
4
+ # 遍历 HCCL_ALGO × HCCL_BUFFSIZE 组合,每组 N_RUNS 次取中位数,找最佳配置。
5
+ # 固定 prompt + seed=0 + n_predict=200 保证可比性。
6
+ set -u
7
+ cd "$(dirname "$0")/.."
8
+
9
+ MODEL="${MODEL_DIR:-/path/to/Qwen3-235B-A22B-Instruct-2507-BF16}"
10
+ BIN="./build/qwen3-moe-aclnn"
11
+ LAUNCH="./scripts/tp_launch.sh"
12
+ TP="${TP_SIZE:-16}"
13
+ N_PREDICT="${N_PREDICT:-150}"
14
+ N_RUNS="${N_RUNS:-2}"
15
+ PROMPT="${PROMPT:-The history of artificial intelligence spans several decades and}"
16
+ VOCAB="tokenizer_data/vocab.bin"
17
+
18
+ OUT=/tmp/bench_hccl_results.csv
19
+ echo "algo,buffsize,runs,best_tgs" > $OUT
20
+
21
+ run_one() {
22
+ local algo="$1" buf="$2"
23
+ local tgs=()
24
+ for r in $(seq 1 $N_RUNS); do
25
+ export HCCL_ALGO="$algo" HCCL_BUFFSIZE="$buf"
26
+ local out
27
+ out=$(${LAUNCH} ${TP} ${BIN} --model-dir "$MODEL" \
28
+ --prompt "$PROMPT" --n-predict $N_PREDICT \
29
+ --vocab "$VOCAB" --seed 0 2>&1 | grep "decode :" | awk '{print $(NF-2)}')
30
+ tgs+=("${out:-0}")
31
+ done
32
+ local sorted=($(printf '%s\n' "${tgs[@]}" | sort -n))
33
+ local best="${sorted[-1]}"
34
+ local csv="$algo,$buf,${tgs[*]},$best"
35
+ echo "$csv" | sed 's/ /|/g' >> $OUT
36
+ printf " %-22s buf=%-4s %s best=%s\n" \
37
+ "${algo:-(auto)}" "$buf" "${tgs[*]}" "$best"
38
+ }
39
+
40
+ # Matrix
41
+ ALGOS=("" "level0:ring" "level0:fullmesh")
42
+ BUFSIZES=("100" "200" "400")
43
+
44
+ echo "HCCL matrix: ${#ALGOS[@]} algos × ${#BUFSIZES[@]} buffsizes × ${N_RUNS} runs each"
45
+ echo "Results → $OUT"
46
+ echo ""
47
+
48
+ for algo in "${ALGOS[@]}"; do
49
+ for buf in "${BUFSIZES[@]}"; do
50
+ run_one "$algo" "$buf"
51
+ done
52
+ done
53
+
54
+ echo ""
55
+ echo "====== Summary (sorted by best TG) ======"
56
+ (head -1 $OUT; tail -n +2 $OUT | sort -t, -k4 -gr) | column -t -s,
scripts/bench_hccl_adv.sh ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+ # bench_hccl_adv.sh — 进阶 HCCL 参数调优
3
+ # 在已确定的 ring:200 baseline 上加入 OP_EXPANSION_MODE=AIV 等 knob
4
+ set -u
5
+ cd "$(dirname "$0")/.."
6
+
7
+ MODEL="${MODEL_DIR:-/path/to/Qwen3-235B-A22B-Instruct-2507-BF16}"
8
+ BIN="./build/qwen3-moe-aclnn"
9
+ LAUNCH="./scripts/tp_launch.sh"
10
+ TP=16
11
+ N_PREDICT=200
12
+ N_RUNS=2
13
+ PROMPT="The history of artificial intelligence spans several decades and"
14
+ VOCAB="tokenizer_data/vocab.bin"
15
+
16
+ OUT=/tmp/bench_hccl_adv.csv
17
+ echo "config,run1,run2,best,median" > $OUT
18
+
19
+ run_one() {
20
+ local name="$1"; shift
21
+ # remaining args are env assignments: KEY=VALUE ...
22
+ local tgs=()
23
+ for r in $(seq 1 $N_RUNS); do
24
+ local out
25
+ # set env vars for this run
26
+ local env_cmd=""
27
+ for a in "$@"; do env_cmd="$env_cmd $a"; done
28
+ out=$(env HCCL_ALGO=level0:ring HCCL_BUFFSIZE=200 $@ \
29
+ ${LAUNCH} ${TP} ${BIN} --model-dir "$MODEL" \
30
+ --prompt "$PROMPT" --n-predict $N_PREDICT \
31
+ --vocab "$VOCAB" --seed 0 2>&1 | grep "decode :" | awk '{print $(NF-2)}')
32
+ tgs+=("${out:-0}")
33
+ done
34
+ local sorted=($(printf '%s\n' "${tgs[@]}" | sort -n))
35
+ local best="${sorted[-1]}"
36
+ local median="${sorted[$((${#sorted[@]}/2))]}"
37
+ echo "$name,${tgs[0]},${tgs[1]},$best,$median" >> $OUT
38
+ printf " %-40s %s best=%s median=%s\n" "$name" "${tgs[*]}" "$best" "$median"
39
+ }
40
+
41
+ echo "Adv HCCL bench: baseline ring:200 + additional knobs"
42
+ echo "Results → $OUT"
43
+ echo ""
44
+
45
+ run_one "baseline (ring+200 only)"
46
+
47
+ run_one "+ OP_EXPANSION_MODE=AIV" HCCL_OP_EXPANSION_MODE=AIV
48
+ run_one "+ OP_BASE_FFTS_MODE=1" HCCL_OP_BASE_FFTS_MODE_ENABLE=1
49
+ run_one "+ OP_EXPANSION=AIV + FFTS=1" HCCL_OP_EXPANSION_MODE=AIV HCCL_OP_BASE_FFTS_MODE_ENABLE=1
50
+ run_one "+ OP_EXPANSION=AIV + BUF=256" HCCL_OP_EXPANSION_MODE=AIV HCCL_BUFFSIZE=256
51
+ run_one "+ OP_EXPANSION=AIV + BUF=512" HCCL_OP_EXPANSION_MODE=AIV HCCL_BUFFSIZE=512
52
+ run_one "+ OP_EXPANSION=AIV + ALGO=fullmesh" HCCL_OP_EXPANSION_MODE=AIV HCCL_ALGO=level0:fullmesh
53
+
54
+ echo ""
55
+ echo "====== Sorted by best TG ======"
56
+ (head -1 $OUT; tail -n +2 $OUT | sort -t, -k4 -gr) | column -t -s,
scripts/bench_hccl_adv2.sh ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+ # bench_hccl_adv2.sh — layer 2 env knob exploration on top of AIV+FFTS=1 baseline.
3
+ # Target: break past 25 t/s MUST barrier.
4
+ set -u
5
+ cd "$(dirname "$0")/.."
6
+
7
+ MODEL="${MODEL_DIR:-/path/to/Qwen3-235B-A22B-Instruct-2507-BF16}"
8
+ BIN="./build/qwen3-moe-aclnn"
9
+ LAUNCH="./scripts/tp_launch.sh"
10
+ TP=16
11
+ N_PREDICT=200
12
+ N_RUNS=3
13
+ LONG_PROMPT="Write a very long detailed essay about artificial intelligence, machine learning, deep learning and their applications in modern society. Include historical context, current state of the art, and future predictions."
14
+ VOCAB="tokenizer_data/vocab.bin"
15
+
16
+ OUT=/tmp/bench_hccl_adv2.csv
17
+ echo "config,runs,best,median" > $OUT
18
+
19
+ run_one() {
20
+ local name="$1"; shift
21
+ local tgs=()
22
+ for r in $(seq 1 $N_RUNS); do
23
+ local out
24
+ out=$(env HCCL_ALGO=level0:ring HCCL_BUFFSIZE=200 \
25
+ HCCL_OP_EXPANSION_MODE=AIV HCCL_OP_BASE_FFTS_MODE_ENABLE=1 \
26
+ "$@" \
27
+ ${LAUNCH} ${TP} ${BIN} --model-dir "$MODEL" \
28
+ --prompt "$LONG_PROMPT" --n-predict $N_PREDICT \
29
+ --vocab "$VOCAB" --seed 0 --no-stream 2>&1 \
30
+ | grep "decode :" | awk '{print $(NF-2)}')
31
+ tgs+=("${out:-0}")
32
+ done
33
+ local sorted=($(printf '%s\n' "${tgs[@]}" | sort -n))
34
+ local best="${sorted[-1]}"
35
+ local median="${sorted[$((${#sorted[@]}/2))]}"
36
+ echo "$name,${tgs[*]},$best,$median" | tr ' ' '|' | sed 's/|/,/' | sed 's/|/ /g' >> $OUT
37
+ printf " %-40s %s best=%s median=%s\n" "$name" "${tgs[*]}" "$best" "$median"
38
+ }
39
+
40
+ echo "Bench: AIV+FFTS baseline + single additional knob"
41
+ echo "$N_RUNS runs × $N_PREDICT tokens"
42
+ echo ""
43
+
44
+ run_one "baseline (AIV + FFTS)"
45
+
46
+ run_one "+ TASK_QUEUE_ENABLE=1" TASK_QUEUE_ENABLE=1
47
+ run_one "+ TASK_QUEUE_ENABLE=2" TASK_QUEUE_ENABLE=2
48
+ run_one "+ HCCL_BUFFSIZE=256" HCCL_BUFFSIZE=256
49
+ run_one "+ HCCL_DETERMINISTIC=false" HCCL_DETERMINISTIC=false
50
+ run_one "+ HCCL_INTRA_ROCE_ENABLE=1" HCCL_INTRA_ROCE_ENABLE=1
51
+ run_one "+ HCCL_CLUSTER_TIMEOUT=600" HCCL_CLUSTER_TIMEOUT=600
52
+ run_one "+ ASCEND_LAUNCH_BLOCKING=0" ASCEND_LAUNCH_BLOCKING=0
53
+
54
+ echo ""
55
+ echo "====== Sorted by best TG ======"
56
+ (head -1 $OUT; tail -n +2 $OUT | sort -t, -k3 -gr) | column -t -s,
scripts/bench_pld.sh ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+ # bench_pld.sh — sweep K × n-gram with corrected causal-with-past mask.
3
+ # Measures TG + accept rate stability across N_RUNS per config.
4
+ set -u
5
+ cd "$(dirname "$0")/.."
6
+
7
+ MODEL="${MODEL_DIR:-/path/to/Qwen3-235B-A22B-Instruct-2507-BF16}"
8
+ BIN="./build/qwen3-moe-aclnn"
9
+ LAUNCH="./scripts/tp_launch.sh"
10
+ TP=16
11
+ N_PREDICT=200
12
+ N_RUNS="${N_RUNS:-3}"
13
+ PROMPT="${PROMPT:-Write a long Python function that computes the Fibonacci sequence with memoization, extensive comments, and type hints.}"
14
+ VOCAB="tokenizer_data/vocab.bin"
15
+
16
+ OUT=/tmp/bench_pld.csv
17
+ echo "k,ngram,run_tgs,best,median,avg_accept" > $OUT
18
+
19
+ run_one() {
20
+ local k="$1" ng="$2"
21
+ local tgs=() accs=()
22
+ for r in $(seq 1 $N_RUNS); do
23
+ local output
24
+ output=$(${LAUNCH} ${TP} ${BIN} --model-dir "$MODEL" \
25
+ --prompt "$PROMPT" --n-predict $N_PREDICT --max-seq 512 \
26
+ --vocab "$VOCAB" --seed 0 --no-stream \
27
+ --pld --pld-k $k --pld-ngram $ng 2>&1)
28
+ local tg
29
+ tg=$(echo "$output" | grep "decode :" | awk '{print $(NF-2)}')
30
+ local acc
31
+ acc=$(echo "$output" | grep "\[pld\]" | grep -oE "avg=[0-9.]+" | cut -d= -f2)
32
+ tgs+=("${tg:-0}")
33
+ accs+=("${acc:-0}")
34
+ done
35
+ local sorted=($(printf '%s\n' "${tgs[@]}" | sort -n))
36
+ local n=${#sorted[@]}
37
+ local best="${sorted[-1]}"
38
+ local median="${sorted[$((n/2))]}"
39
+ local accs_avg=$(printf '%s\n' "${accs[@]}" | awk '{s+=$1} END {printf "%.2f", s/NR}')
40
+ echo "$k,$ng,$(IFS=/; echo "${tgs[*]}"),$best,$median,$accs_avg" >> $OUT
41
+ printf " K=%-2d ng=%-1d runs=[%s] best=%s median=%s accept_avg=%s\n" \
42
+ "$k" "$ng" "${tgs[*]}" "$best" "$median" "$accs_avg"
43
+ }
44
+
45
+ echo "PLD sweep on '$PROMPT' ($N_RUNS runs × $N_PREDICT tokens)"
46
+ echo ""
47
+
48
+ for k in 2 4 6 8 12; do
49
+ for ng in 1 2 3; do
50
+ run_one $k $ng
51
+ done
52
+ done
53
+
54
+ # Baseline for reference
55
+ echo ""
56
+ echo "Baseline (no PLD):"
57
+ tgs=()
58
+ for r in $(seq 1 $N_RUNS); do
59
+ tg=$(${LAUNCH} ${TP} ${BIN} --model-dir "$MODEL" \
60
+ --prompt "$PROMPT" --n-predict $N_PREDICT --max-seq 512 \
61
+ --vocab "$VOCAB" --seed 0 --no-stream 2>&1 | grep "decode :" | awk '{print $(NF-2)}')
62
+ tgs+=("${tg:-0}")
63
+ done
64
+ sorted=($(printf '%s\n' "${tgs[@]}" | sort -n))
65
+ echo " baseline: ${tgs[*]} median=${sorted[$((${#sorted[@]}/2))]}"
66
+
67
+ echo ""
68
+ echo "====== Sorted by median TG ======"
69
+ (head -1 $OUT; tail -n +2 $OUT | sort -t, -k5 -gr) | column -t -s,
scripts/bench_pld_k.sh ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+ # bench_pld_k.sh — isolated K sweep with FIXED K (no adaptive) to characterize raw K effect.
3
+ # Larger K = more draft candidates per verify. Peak observed accept=7.38 suggests K=8 not saturated.
4
+ set -u
5
+ cd "$(dirname "$0")/.."
6
+
7
+ MODEL="${MODEL_DIR:-/path/to/Qwen3-235B-A22B-Instruct-2507-BF16}"
8
+ BIN="./build/qwen3-moe-aclnn"
9
+ LAUNCH="./scripts/tp_launch.sh"
10
+ TP=16
11
+ N_PREDICT=200
12
+ N_RUNS=3
13
+ PROMPT="Write a long Python function that computes the Fibonacci sequence with memoization, extensive comments, and type hints."
14
+ VOCAB="tokenizer_data/vocab.bin"
15
+
16
+ OUT=/tmp/bench_pld_k.csv
17
+ echo "k,runs,median,max,avg_accept" > $OUT
18
+
19
+ for K in 4 6 8 10 12 16; do
20
+ tgs=() accs=()
21
+ for r in $(seq 1 $N_RUNS); do
22
+ out=$(${LAUNCH} ${TP} ${BIN} --model-dir "$MODEL" \
23
+ --prompt "$PROMPT" --n-predict $N_PREDICT --max-seq 512 \
24
+ --vocab "$VOCAB" --seed 0 --no-stream \
25
+ --pld --pld-k $K --pld-ngram 1 --pld-fixed-k 2>&1)
26
+ tg=$(echo "$out" | grep "decode :" | awk '{print $(NF-2)}')
27
+ acc=$(echo "$out" | grep "\[pld\]" | grep -oE "avg=[0-9.]+" | cut -d= -f2)
28
+ tgs+=("${tg:-0}")
29
+ accs+=("${acc:-0}")
30
+ done
31
+ sorted=($(printf '%s\n' "${tgs[@]}" | sort -n))
32
+ median="${sorted[$((${#sorted[@]}/2))]}"
33
+ max="${sorted[-1]}"
34
+ accs_avg=$(printf '%s\n' "${accs[@]}" | awk '{s+=$1} END {printf "%.2f", s/NR}')
35
+ echo "$K,$(IFS=/; echo "${tgs[*]}"),$median,$max,$accs_avg" >> $OUT
36
+ printf " K=%-2d runs=[%s] median=%s max=%s accept=%s\n" "$K" "${tgs[*]}" "$median" "$max" "$accs_avg"
37
+ done
38
+
39
+ echo ""
40
+ echo "====== Sorted by median ======"
41
+ (head -1 $OUT; tail -n +2 $OUT | sort -t, -k3 -gr) | column -t -s,
scripts/bench_pld_safe.sh ADDED
@@ -0,0 +1,154 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+ # bench_pld_safe.sh — PLD benchmark with output correctness check.
3
+ # Unlike bench_tg.sh (which only reports TG numbers), this wrapper also inspects the
4
+ # generated text for degeneration signals (consecutive identical tokens / very low
5
+ # distinct-token ratio in the tail) and flags runs whose high TG came from dead-loop
6
+ # output rather than real acceleration.
7
+ #
8
+ # Usage: ./scripts/bench_pld_safe.sh [N_RUNS] [PROMPT_FILE]
9
+ # Prompts with "|" separator: "tag|prompt text"
10
+ # Default: tests multiple prompt classes and reports which ones PLD helps safely.
11
+ set -u
12
+ cd "$(dirname "$0")/.."
13
+
14
+ MODEL="${MODEL_DIR:-/path/to/Qwen3-235B-A22B-Instruct-2507-BF16}"
15
+ BIN="./build/qwen3-moe-aclnn"
16
+ N_RUNS="${1:-3}"
17
+ N_PREDICT="${N_PREDICT:-120}"
18
+ VOCAB="tokenizer_data/vocab.bin"
19
+
20
+ # Default prompt suite: one per class. Override via PROMPTS env or arg 2 (file with "tag|prompt" per line).
21
+ default_prompts=(
22
+ "story|Once upon a time, in a small village,"
23
+ "factual|The capital of France is"
24
+ "code|Write a Python function that computes Fibonacci."
25
+ "essay|The history of artificial intelligence spans several decades and"
26
+ )
27
+
28
+ if [ "${2:-}" != "" ] && [ -f "${2:-}" ]; then
29
+ mapfile -t prompts < "$2"
30
+ else
31
+ prompts=("${default_prompts[@]}")
32
+ fi
33
+
34
+ # ----- Correctness classifier -----
35
+ # Reads generated text from stdin, returns:
36
+ # OK — no loop signals
37
+ # LOOP_N — N+ consecutive identical non-space words detected
38
+ # LOW_DIVERSITY — tail 40 words have < 10 distinct words (heavy repetition)
39
+ classify_output() {
40
+ awk '
41
+ {
42
+ # Tokenize on whitespace; strip punct at edges for comparison.
43
+ n = split($0, w, /[[:space:]]+/);
44
+ for (i = 1; i <= n; i++) {
45
+ gsub(/^[[:punct:]]+|[[:punct:]]+$/, "", w[i]);
46
+ if (w[i] == "") continue;
47
+ words[++nw] = tolower(w[i]);
48
+ }
49
+ }
50
+ END {
51
+ if (nw < 5) { print "OK"; exit }
52
+ # consecutive-same detection
53
+ run = 1; max_run = 1;
54
+ for (i = 2; i <= nw; i++) {
55
+ if (words[i] == words[i-1]) { run++; if (run > max_run) max_run = run; }
56
+ else run = 1;
57
+ }
58
+ if (max_run >= 6) { printf "LOOP_%d\n", max_run; exit }
59
+
60
+ # tail diversity: last 40 words
61
+ tail_start = nw - 39; if (tail_start < 1) tail_start = 1;
62
+ delete seen;
63
+ distinct = 0;
64
+ for (i = tail_start; i <= nw; i++) {
65
+ if (!(words[i] in seen)) { seen[words[i]] = 1; distinct++; }
66
+ }
67
+ tail_n = nw - tail_start + 1;
68
+ if (tail_n >= 20 && distinct < 10) {
69
+ printf "LOW_DIVERSITY_%d/%d\n", distinct, tail_n;
70
+ exit;
71
+ }
72
+ print "OK";
73
+ }'
74
+ }
75
+
76
+ run_once() {
77
+ local prompt="$1"
78
+ local extra_flags="$2"
79
+ # Launch. The binary prints to stdout: rank/cli headers, runner loading lines,
80
+ # generated text (--no-stream), then perf lines. pld/warn go to stderr.
81
+ local stdout_file=$(mktemp)
82
+ local stderr_file=$(mktemp)
83
+ # Ensure no lockfile leftover.
84
+ ssh_cleanup_lockfile
85
+ ./scripts/tp_launch.sh 16 $BIN --model-dir "$MODEL" \
86
+ --prompt "$prompt" --n-predict $N_PREDICT \
87
+ --vocab "$VOCAB" --seed 0 --no-stream --temperature 0 \
88
+ $extra_flags 1>"$stdout_file" 2>"$stderr_file"
89
+ # TG lives on stdout (from printf in binary).
90
+ local tg=$(grep "\[perf\] decode" "$stdout_file" | awk '{print $(NF-2)}')
91
+ # Generated text: the line that begins with the prompt (--no-stream echoes prompt+text).
92
+ local gen_text=$(grep -F -- "$prompt" "$stdout_file" | grep -v '^\[' | tail -1)
93
+ local stripped="${gen_text#$prompt}"
94
+ local verdict=$(echo "$stripped" | classify_output)
95
+ local has_warn=""
96
+ if grep -q "\[warn\]" "$stderr_file"; then has_warn="WARN"; fi
97
+ local pld_line=$(grep "\[pld\]" "$stderr_file" | tail -1 | sed 's/^\[pld\] //')
98
+ rm -f "$stdout_file" "$stderr_file"
99
+ echo "${tg:-0}|${verdict}|${has_warn}|${pld_line}"
100
+ }
101
+
102
+ ssh_cleanup_lockfile() {
103
+ rm -f /tmp/hccl_root_info.bin 2>/dev/null || true
104
+ }
105
+
106
+ bench_prompt() {
107
+ local tag="$1"; local prompt="$2"; local flags="$3"
108
+ echo ""
109
+ echo "=== [$tag] $(echo "$prompt" | head -c 50)... (flags: ${flags:-none}) ==="
110
+ local tgs=() verdicts=() warns=() plds=()
111
+ for r in $(seq 1 $N_RUNS); do
112
+ result=$(run_once "$prompt" "$flags")
113
+ IFS='|' read -r tg verdict warn pld <<< "$result"
114
+ printf " run %d: TG=%s verdict=%s %s\n" "$r" "$tg" "$verdict" "$warn"
115
+ [ -n "$pld" ] && printf " %s\n" "$pld"
116
+ tgs+=("${tg:-0}"); verdicts+=("$verdict"); warns+=("$warn")
117
+ rm -f /tmp/hccl_root_info.bin
118
+ done
119
+ # Split good vs degraded
120
+ local good_tgs=() bad_tgs=()
121
+ for i in "${!tgs[@]}"; do
122
+ if [ "${verdicts[$i]}" = "OK" ]; then good_tgs+=("${tgs[$i]}"); else bad_tgs+=("${tgs[$i]}"); fi
123
+ done
124
+ local n_good=${#good_tgs[@]}
125
+ local n_bad=${#bad_tgs[@]}
126
+ echo " → $n_good/$N_RUNS OK, $n_bad/$N_RUNS degraded"
127
+ if [ $n_good -gt 0 ]; then
128
+ local mean=$(printf '%s\n' "${good_tgs[@]}" | awk '{s+=$1} END {printf "%.2f", s/NR}')
129
+ echo " → OK mean TG: $mean t/s (values: ${good_tgs[*]})"
130
+ fi
131
+ if [ $n_bad -gt 0 ]; then
132
+ local bad_mean=$(printf '%s\n' "${bad_tgs[@]}" | awk '{s+=$1} END {printf "%.2f", s/NR}')
133
+ echo " → degraded mean TG: $bad_mean t/s (DO NOT REPORT as speedup) (values: ${bad_tgs[*]})"
134
+ fi
135
+ }
136
+
137
+ echo "bench_pld_safe: $N_RUNS runs × $N_PREDICT tokens per prompt; comparing [no-pld, pld+guard, pld+no-guard]"
138
+
139
+ for entry in "${prompts[@]}"; do
140
+ tag="${entry%%|*}"
141
+ prompt="${entry#*|}"
142
+ bench_prompt "$tag/base" "$prompt" ""
143
+ bench_prompt "$tag/pld+guard" "$prompt" "--pld"
144
+ bench_prompt "$tag/pld-raw" "$prompt" "--pld --pld-no-guard"
145
+ done
146
+
147
+ echo ""
148
+ echo "=========================================================="
149
+ echo "Interpretation:"
150
+ echo " OK mean TG is the only honest number to report."
151
+ echo " Any 'degraded' result with high TG is a dead-loop artifact."
152
+ echo " Expected: pld+guard matches or beats base on creative/story prompts,"
153
+ echo " matches base on factual/code prompts (drafts rejected → fallback to single decode)."
154
+ echo " pld-raw (no guard) on repetitive prompts produces 'degraded' with high TG."
scripts/bench_tg.sh ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+ # bench_tg.sh — stable TG measurement: N runs × 200 tokens, drop cold-starts, report median.
3
+ #
4
+ # Usage: ./scripts/bench_tg.sh [N_RUNS] (default 5)
5
+ # LCA_WARMUP=3 ./scripts/bench_tg.sh (with warmup enabled)
6
+ set -u
7
+ cd "$(dirname "$0")/.."
8
+
9
+ MODEL="${MODEL_DIR:-/path/to/Qwen3-235B-A22B-Instruct-2507-BF16}"
10
+ BIN="./build/qwen3-moe-aclnn"
11
+ N_RUNS="${1:-5}"
12
+ N_PREDICT="${N_PREDICT:-200}"
13
+ PROMPT="The history of artificial intelligence spans several decades and"
14
+ VOCAB="tokenizer_data/vocab.bin"
15
+
16
+ echo "bench_tg: $N_RUNS runs × $N_PREDICT tokens (LCA_WARMUP=${LCA_WARMUP:-0})"
17
+ tgs=()
18
+ for r in $(seq 1 $N_RUNS); do
19
+ local_out=$(./scripts/tp_launch.sh 16 $BIN --model-dir "$MODEL" \
20
+ --prompt "$PROMPT" --n-predict $N_PREDICT \
21
+ --vocab "$VOCAB" --seed 0 2>&1 | grep "decode :" | awk '{print $(NF-2)}')
22
+ printf " run %d: %s t/s\n" "$r" "$local_out"
23
+ tgs+=("${local_out:-0}")
24
+ done
25
+
26
+ echo ""
27
+ echo "====== Summary ======"
28
+ sorted=($(printf '%s\n' "${tgs[@]}" | sort -n))
29
+ n=${#sorted[@]}
30
+ mid=$((n / 2))
31
+ median="${sorted[$mid]}"
32
+ min="${sorted[0]}"
33
+ max="${sorted[-1]}"
34
+ mean=$(printf '%s\n' "${tgs[@]}" | awk '{s+=$1} END {printf "%.2f", s/NR}')
35
+
36
+ echo " all : ${tgs[*]}"
37
+ echo " min : $min t/s"
38
+ echo " median : $median t/s"
39
+ echo " mean : $mean t/s"
40
+ echo " max : $max t/s"
scripts/export_vocab.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Export Qwen3 tokenizer vocab to a simple binary format.
3
+
4
+ Format (little-endian):
5
+ u32 num_tokens
6
+ for each id in [0, num_tokens):
7
+ u32 byte_length
8
+ u8[byte_length] utf8_bytes
9
+
10
+ Also emits special_tokens.txt with id + content pairs for reference.
11
+ """
12
+ import json, sys, struct, os
13
+
14
+ model_dir = sys.argv[1] if len(sys.argv) > 1 else '/path/to/Qwen3-235B-A22B-Instruct-2507-BF16'
15
+ out_dir = sys.argv[2] if len(sys.argv) > 2 else 'tokenizer_data'
16
+ os.makedirs(out_dir, exist_ok=True)
17
+
18
+ with open(os.path.join(model_dir, 'tokenizer.json'), 'r') as f:
19
+ tok = json.load(f)
20
+
21
+ # Byte-level decoder map: HF Qwen uses byte-level BPE like GPT-2
22
+ # Each non-ASCII vocab token is a mapping of U+0100..U+017F etc back to raw bytes.
23
+ # For decode we just need the reverse map from printable chars to raw bytes.
24
+ def build_byte_decoder():
25
+ bs = list(range(ord('!'), ord('~')+1)) + list(range(ord('¡'), ord('¬')+1)) + list(range(ord('®'), ord('ÿ')+1))
26
+ cs = bs[:]
27
+ n = 0
28
+ for b in range(2**8):
29
+ if b not in bs:
30
+ bs.append(b)
31
+ cs.append(2**8 + n)
32
+ n += 1
33
+ return {chr(c): bytes([b]) for b, c in zip(bs, cs)}
34
+
35
+ byte_decoder = build_byte_decoder()
36
+
37
+ # Merge vocab + added_tokens into id -> utf8_bytes lookup
38
+ vocab = tok['model']['vocab'] # {token_str: id}
39
+ added = tok.get('added_tokens', []) # list of {id, content, ...}
40
+
41
+ id_to_bytes = {}
42
+ for token, tid in vocab.items():
43
+ # Decode byte-level encoding back to raw utf8 bytes
44
+ raw = b''
45
+ for ch in token:
46
+ if ch in byte_decoder:
47
+ raw += byte_decoder[ch]
48
+ else:
49
+ raw += ch.encode('utf-8')
50
+ id_to_bytes[int(tid)] = raw
51
+
52
+ for a in added:
53
+ # Special tokens stored as raw utf8
54
+ id_to_bytes[int(a['id'])] = a['content'].encode('utf-8')
55
+
56
+ max_id = max(id_to_bytes.keys())
57
+ num = max_id + 1
58
+ print(f"max_id = {max_id}, num_tokens = {num}")
59
+ print(f"num_special_tokens = {len(added)}")
60
+
61
+ # Write vocab.bin
62
+ vocab_path = os.path.join(out_dir, 'vocab.bin')
63
+ with open(vocab_path, 'wb') as f:
64
+ f.write(struct.pack('<I', num))
65
+ for i in range(num):
66
+ b = id_to_bytes.get(i, b'')
67
+ f.write(struct.pack('<I', len(b)))
68
+ f.write(b)
69
+ print(f"Wrote {vocab_path} ({os.path.getsize(vocab_path)} bytes)")
70
+
71
+ # Write special tokens
72
+ with open(os.path.join(out_dir, 'special_tokens.txt'), 'w') as f:
73
+ for a in added:
74
+ f.write(f"{a['id']}\t{a['content']}\n")
75
+ print(f"Wrote special_tokens.txt")
76
+
77
+ # Verify via a known prompt
78
+ from transformers import AutoTokenizer
79
+ atok = AutoTokenizer.from_pretrained(model_dir)
80
+ test = "The capital of France is"
81
+ ids = atok.encode(test)
82
+ print(f"\nTest encode '{test}' -> {ids}")
83
+ decoded = ''.join(id_to_bytes.get(i, b'?').decode('utf-8', errors='replace') for i in ids)
84
+ print(f"Our decode: '{decoded}'")
85
+ print(f"HF decode: '{atok.decode(ids)}'")
scripts/gen_attention_reference.py ADDED
@@ -0,0 +1,179 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Generate a single-layer attention forward reference for Qwen3-235B layer 0.
3
+
4
+ Input: token ids (representing "The capital of France is")
5
+ Output: hidden_states after layer 0 attention (residual already added).
6
+ Also dumps all intermediate tensors for step-wise debugging.
7
+ """
8
+ import os, json, math, struct
9
+ import torch
10
+ import torch_npu
11
+ from safetensors.torch import load_file
12
+
13
+ torch.npu.set_device(0)
14
+ torch.set_grad_enabled(False)
15
+
16
+ MODEL_DIR = '/path/to/Qwen3-235B-A22B-Instruct-2507-BF16'
17
+ OUT_DIR = 'tests/attn_data'
18
+ os.makedirs(OUT_DIR, exist_ok=True)
19
+
20
+ cfg = json.load(open(os.path.join(MODEL_DIR, 'config.json')))
21
+ D = cfg['hidden_size'] # 4096
22
+ Hq = cfg['num_attention_heads'] # 64
23
+ Hkv = cfg['num_key_value_heads'] # 4
24
+ Dh = cfg['head_dim'] # 128
25
+ Q_DIM = Hq * Dh # 8192
26
+ KV_DIM = Hkv * Dh # 512
27
+ eps = cfg['rms_norm_eps']
28
+ theta = cfg['rope_theta'] # 5e6 for Qwen3-235B
29
+
30
+ # ---- Find which safetensors shard contains layer 0 attention + input_layernorm ----
31
+ idx = json.load(open(os.path.join(MODEL_DIR, 'model.safetensors.index.json')))
32
+ wm = idx['weight_map']
33
+
34
+ needed = [
35
+ 'model.embed_tokens.weight',
36
+ 'model.layers.0.input_layernorm.weight',
37
+ 'model.layers.0.self_attn.q_proj.weight',
38
+ 'model.layers.0.self_attn.k_proj.weight',
39
+ 'model.layers.0.self_attn.v_proj.weight',
40
+ 'model.layers.0.self_attn.o_proj.weight',
41
+ 'model.layers.0.self_attn.q_norm.weight',
42
+ 'model.layers.0.self_attn.k_norm.weight',
43
+ ]
44
+ shards = sorted({wm[n] for n in needed})
45
+ print("Need to load shards:", shards)
46
+
47
+ weights = {}
48
+ for sh in shards:
49
+ t = load_file(os.path.join(MODEL_DIR, sh))
50
+ for n in needed:
51
+ if n in t:
52
+ weights[n] = t[n].to('npu')
53
+ print("loaded:", list(weights.keys()))
54
+
55
+ # ---- Forward ----
56
+ # Input tokens (from tokenizer: "The capital of France is")
57
+ token_ids = torch.tensor([785, 6722, 315, 9625, 374], dtype=torch.long).npu()
58
+ S = token_ids.shape[0]
59
+ print(f"S = {S}")
60
+
61
+ # Embedding lookup
62
+ x = weights['model.embed_tokens.weight'][token_ids] # [S, D]
63
+ x = x.unsqueeze(0) # [1, S, D]
64
+ print("embed x:", x.shape, x.dtype)
65
+
66
+ # Residual
67
+ residual = x
68
+
69
+ # Input layernorm (RMSNorm)
70
+ ln = weights['model.layers.0.input_layernorm.weight']
71
+ xn, _ = torch_npu.npu_rms_norm(x, ln, epsilon=eps)
72
+ print("after_input_norm xn:", xn.shape)
73
+
74
+ # Q/K/V projections
75
+ Wq = weights['model.layers.0.self_attn.q_proj.weight']
76
+ Wk = weights['model.layers.0.self_attn.k_proj.weight']
77
+ Wv = weights['model.layers.0.self_attn.v_proj.weight']
78
+ q = torch.matmul(xn, Wq.t()) # [1, S, Q_DIM]
79
+ k = torch.matmul(xn, Wk.t()) # [1, S, KV_DIM]
80
+ v = torch.matmul(xn, Wv.t())
81
+
82
+ # Reshape to heads
83
+ q = q.view(1, S, Hq, Dh)
84
+ k = k.view(1, S, Hkv, Dh)
85
+ v = v.view(1, S, Hkv, Dh)
86
+
87
+ # Per-head RMSNorm on head_dim (Qwen3 specific)
88
+ qn_w = weights['model.layers.0.self_attn.q_norm.weight'] # [Dh]
89
+ kn_w = weights['model.layers.0.self_attn.k_norm.weight']
90
+ q_normed, _ = torch_npu.npu_rms_norm(q, qn_w, epsilon=eps)
91
+ k_normed, _ = torch_npu.npu_rms_norm(k, kn_w, epsilon=eps)
92
+
93
+ # RoPE: compute cos/sin for positions [0, S)
94
+ position_ids = torch.arange(S, device='npu').unsqueeze(0) # [1, S]
95
+ inv_freq = 1.0 / (theta ** (torch.arange(0, Dh, 2, dtype=torch.float32).npu() / Dh))
96
+ freqs = position_ids.float().unsqueeze(-1) * inv_freq.unsqueeze(0).unsqueeze(0) # [1, S, Dh/2]
97
+ # Concat (half, half) to get [1, S, Dh]
98
+ emb = torch.cat([freqs, freqs], dim=-1)
99
+ cos = emb.cos().to(torch.bfloat16) # [1, S, Dh]
100
+ sin = emb.sin().to(torch.bfloat16)
101
+
102
+ # Apply RoPE — npu_apply_rotary_pos_emb expects BSND layout
103
+ # cos/sin shape: [1, S, 1, Dh] for broadcast over heads
104
+ cos_b = cos.unsqueeze(2)
105
+ sin_b = sin.unsqueeze(2)
106
+ q_roped, k_roped = torch_npu.npu_apply_rotary_pos_emb(q_normed, k_normed, cos_b, sin_b)
107
+
108
+ # Flatten for FIAS (BSH layout)
109
+ q_bsh = q_roped.reshape(1, S, Q_DIM)
110
+ k_bsh = k_roped.reshape(1, S, KV_DIM)
111
+ v_bsh = v.reshape(1, S, KV_DIM)
112
+
113
+ # FIAS with causal mask for prefill
114
+ scale = 1.0 / math.sqrt(Dh)
115
+ # sparse_mode=3 requires fixed 2048×2048 mask
116
+ MASK_SIZE = 2048
117
+ mask = torch.triu(torch.ones(MASK_SIZE, MASK_SIZE, dtype=torch.bool, device='npu'), diagonal=1)
118
+ mask = mask.view(1, 1, MASK_SIZE, MASK_SIZE)
119
+ attn_out, _ = torch_npu.npu_fused_infer_attention_score(
120
+ q_bsh, k_bsh, v_bsh,
121
+ num_heads=Hq,
122
+ num_key_value_heads=Hkv,
123
+ scale=scale,
124
+ input_layout="BSH",
125
+ sparse_mode=3,
126
+ atten_mask=mask,
127
+ actual_seq_lengths=[S],
128
+ actual_seq_lengths_kv=[S],
129
+ )
130
+ print("attn_out:", attn_out.shape) # [1, S, Q_DIM]
131
+
132
+ # Output projection
133
+ Wo = weights['model.layers.0.self_attn.o_proj.weight']
134
+ o = torch.matmul(attn_out, Wo.t()) # [1, S, D]
135
+
136
+ # Residual add
137
+ out = residual + o
138
+ print("out:", out.shape, out[0, 0, :4].float().tolist())
139
+
140
+ # ---- Dump ----
141
+ def dump(name, t):
142
+ p = os.path.join(OUT_DIR, name + '.bin')
143
+ a = t.contiguous().cpu().view(torch.int16).numpy().astype('int16')
144
+ open(p, 'wb').write(a.tobytes())
145
+
146
+ # Save token_ids
147
+ with open(os.path.join(OUT_DIR, 'token_ids.bin'), 'wb') as f:
148
+ f.write(struct.pack('<i', S))
149
+ for tid in token_ids.cpu().tolist():
150
+ f.write(struct.pack('<i', tid))
151
+
152
+ # Save inputs
153
+ dump('x_input', x) # embed result
154
+ dump('x_normed', xn)
155
+ dump('q_normed', q_normed)
156
+ dump('k_normed', k_normed)
157
+ dump('q_roped', q_roped)
158
+ dump('k_roped', k_roped)
159
+ dump('cos', cos)
160
+ dump('sin', sin)
161
+ dump('attn_out', attn_out)
162
+ dump('final_out', out)
163
+ # Save weights used (dtype=BF16)
164
+ for name, path_name in [
165
+ ('model.layers.0.input_layernorm.weight', 'w_input_norm'),
166
+ ('model.layers.0.self_attn.q_proj.weight', 'w_q_proj'),
167
+ ('model.layers.0.self_attn.k_proj.weight', 'w_k_proj'),
168
+ ('model.layers.0.self_attn.v_proj.weight', 'w_v_proj'),
169
+ ('model.layers.0.self_attn.o_proj.weight', 'w_o_proj'),
170
+ ('model.layers.0.self_attn.q_norm.weight', 'w_q_norm'),
171
+ ('model.layers.0.self_attn.k_norm.weight', 'w_k_norm'),
172
+ ]:
173
+ dump(path_name, weights[name])
174
+
175
+ with open(os.path.join(OUT_DIR, 'shape.txt'), 'w') as f:
176
+ f.write(f"S={S}\nD={D}\nHq={Hq}\nHkv={Hkv}\nDh={Dh}\nQ_DIM={Q_DIM}\nKV_DIM={KV_DIM}\neps={eps}\ntheta={theta}\n")
177
+
178
+ print("\nAll dumps in:", OUT_DIR)
179
+ print("Final output first 4:", out[0, 0, :4].float().cpu().tolist())
scripts/gen_gmm_reference.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Generate a GMM reference case using torch_npu.npu_grouped_matmul.
3
+
4
+ Dumps: x, w (unpermuted ggml-like layout), group_list, and y_ref to binary files
5
+ for the C++ POC to load and validate against.
6
+ """
7
+ import os
8
+ import sys
9
+ import numpy as np
10
+ import struct
11
+
12
+ # Enable torch_npu
13
+ os.environ.setdefault('LD_LIBRARY_PATH', '')
14
+ import torch
15
+ import torch_npu
16
+
17
+ torch.npu.set_device(0)
18
+ torch.manual_seed(42)
19
+
20
+ # Toy Qwen3-like MoE shape: D=64 hidden, I=32 intermediate, E=8 experts, N*K=16 expanded tokens
21
+ # (small enough to eyeball; large enough to catch layout bugs)
22
+ D, I, E, TOTAL = 64, 32, 8, 16
23
+
24
+ # Input x: [TOTAL, D] BF16 — expanded routed tokens
25
+ x = torch.randn(TOTAL, D, dtype=torch.bfloat16).npu()
26
+
27
+ # Weight w: per-expert [I, D] BF16 — gate/up has this shape in HF
28
+ # We will stack into [E, I, D] and also provide [E, D, I] permuted for comparison
29
+ w_per_expert = [torch.randn(I, D, dtype=torch.bfloat16).npu() for _ in range(E)]
30
+ w_stacked_IDL = torch.stack(w_per_expert, dim=0) # [E, I, D]
31
+
32
+ # group_list: counts of tokens per expert, sum = TOTAL
33
+ group_list = torch.tensor([3, 2, 1, 2, 1, 3, 2, 2], dtype=torch.int64).npu()
34
+ assert group_list.sum().item() == TOTAL
35
+
36
+ # Reference: use torch_npu.npu_grouped_matmul
37
+ # Per cann-recipes: weight needs to be in [E, D, I] for matmul y = x @ w (y shape [total, I])
38
+ # i.e. per-expert w is transposed from HF's [I, D] to [D, I]
39
+ w_transposed = w_stacked_IDL.transpose(1, 2).contiguous() # [E, D, I]
40
+
41
+ # Call GMM: y = x @ w, result [TOTAL, I]
42
+ y_ref = torch_npu.npu_grouped_matmul(
43
+ [x], # x list
44
+ [w_transposed], # weight list (transposed)
45
+ group_list=group_list,
46
+ group_type=0,
47
+ group_list_type=1, # counts
48
+ split_item=3 # single-in single-out
49
+ )[0] # unwrap tensor list
50
+
51
+ print("x shape:", x.shape, x.dtype)
52
+ print("w_stacked_IDL shape:", w_stacked_IDL.shape, w_stacked_IDL.dtype)
53
+ print("w_transposed shape:", w_transposed.shape)
54
+ print("group_list:", group_list.cpu().tolist())
55
+ print("y_ref shape:", y_ref.shape)
56
+ print("y_ref[0, 0:4]:", y_ref[0, 0:4].cpu().float().tolist())
57
+
58
+ # Save binary dumps
59
+ out_dir = 'tests/poc_data'
60
+ os.makedirs(out_dir, exist_ok=True)
61
+
62
+ def dump_bf16(name, tensor):
63
+ path = os.path.join(out_dir, name + '.bin')
64
+ arr = tensor.contiguous().cpu().view(torch.int16).numpy().astype('int16')
65
+ with open(path, 'wb') as f:
66
+ f.write(arr.tobytes())
67
+ print(f" wrote {name}.bin: {arr.shape} int16 = BF16 raw, {arr.nbytes} bytes")
68
+
69
+ def dump_int64(name, tensor):
70
+ path = os.path.join(out_dir, name + '.bin')
71
+ arr = tensor.contiguous().cpu().numpy().astype('int64')
72
+ with open(path, 'wb') as f:
73
+ f.write(arr.tobytes())
74
+ print(f" wrote {name}.bin: {arr.shape} int64, {arr.nbytes} bytes")
75
+
76
+ # HF-style weight layout (ggml stores similar): [E, I, D] = what C++ gets from safetensors after stack
77
+ dump_bf16('x', x)
78
+ dump_bf16('w_hf_EID', w_stacked_IDL) # C++ input weight (HF layout)
79
+ dump_bf16('w_ref_EDI', w_transposed) # Already-permuted reference (for debug)
80
+ dump_int64('group_list', group_list)
81
+ dump_bf16('y_ref', y_ref)
82
+
83
+ # Also dump shapes header
84
+ with open(os.path.join(out_dir, 'shapes.txt'), 'w') as f:
85
+ f.write(f"D={D}\nI={I}\nE={E}\nTOTAL={TOTAL}\n")
86
+
87
+ print("\nAll dumps in:", out_dir)
88
+ print("\nTo validate: C++ loads w_hf_EID, permutes [0,2,1] to [E,D,I], NZ-casts, calls GMMV4, "
89
+ "compares output to y_ref.")
scripts/gen_mm_reference.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Generate a linear (y = x @ W.T) reference for a realistic Qwen3 attention shape."""
3
+ import os, struct, torch, torch_npu
4
+ torch.npu.set_device(0)
5
+ torch.manual_seed(7)
6
+
7
+ N, D, OUT = 5, 4096, 8192 # prompt len, hidden, q_dim
8
+
9
+ x = torch.randn(N, D, dtype=torch.bfloat16).npu()
10
+ W = torch.randn(OUT, D, dtype=torch.bfloat16).npu() # HF layout [out, in]
11
+ # y = x @ W.T, shape [N, OUT]
12
+ y_ref = torch.matmul(x, W.t())
13
+
14
+ out_dir = 'tests/mm_data'
15
+ os.makedirs(out_dir, exist_ok=True)
16
+ def dump(name, t):
17
+ p = os.path.join(out_dir, name + '.bin')
18
+ a = t.contiguous().cpu().view(torch.int16).numpy().astype('int16')
19
+ open(p, 'wb').write(a.tobytes())
20
+ dump('x', x); dump('W', W); dump('y_ref', y_ref)
21
+ with open(os.path.join(out_dir, 'shape.txt'), 'w') as f:
22
+ f.write(f"N={N}\nD={D}\nOUT={OUT}\n")
23
+ print(f"N={N} D={D} OUT={OUT}, y_ref[0, :4] = {y_ref[0, :4].float().cpu().tolist()}")
scripts/gen_moe_reference.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Generate MoE layer forward reference for Qwen3-235B layer 0.
3
+
4
+ Input: hidden_states from attention output (use attn_data/final_out.bin as input — realistic).
5
+ Output: hidden_states after MoE + residual.
6
+ """
7
+ import os, json, math, torch, torch_npu
8
+ from safetensors.torch import load_file
9
+
10
+ torch.npu.set_device(0)
11
+ torch.set_grad_enabled(False)
12
+
13
+ MODEL_DIR = '/path/to/Qwen3-235B-A22B-Instruct-2507-BF16'
14
+ OUT_DIR = 'tests/moe_data'
15
+ os.makedirs(OUT_DIR, exist_ok=True)
16
+
17
+ cfg = json.load(open(os.path.join(MODEL_DIR, 'config.json')))
18
+ D = cfg['hidden_size'] # 4096
19
+ I = cfg['moe_intermediate_size'] # 1536
20
+ E = cfg['num_experts'] # 128
21
+ TK = cfg['num_experts_per_tok'] # 8
22
+ eps = cfg['rms_norm_eps']
23
+ norm_topk = cfg.get('norm_topk_prob', True)
24
+
25
+ # Use attention output as input (more realistic than random)
26
+ attn_out_raw = open('tests/attn_data/final_out.bin', 'rb').read()
27
+ S = 5
28
+ x_in = torch.frombuffer(bytearray(attn_out_raw), dtype=torch.int16).view(1, S, D).view(torch.bfloat16).npu()
29
+ print(f"x_in: {x_in.shape}")
30
+
31
+ # Load required weights for layer 0
32
+ idx = json.load(open(os.path.join(MODEL_DIR, 'model.safetensors.index.json')))
33
+ wm = idx['weight_map']
34
+
35
+ needed = [f'model.layers.0.post_attention_layernorm.weight',
36
+ f'model.layers.0.mlp.gate.weight']
37
+ for e in range(E):
38
+ for p in ['gate_proj', 'up_proj', 'down_proj']:
39
+ needed.append(f'model.layers.0.mlp.experts.{e}.{p}.weight')
40
+
41
+ shards = sorted({wm[n] for n in needed})
42
+ weights = {}
43
+ for sh in shards:
44
+ t = load_file(os.path.join(MODEL_DIR, sh))
45
+ for n in needed:
46
+ if n in t:
47
+ weights[n] = t[n].to('npu')
48
+ print("loaded %d tensors from %d shards" % (len(weights), len(shards)))
49
+
50
+ # Residual = input
51
+ residual = x_in
52
+
53
+ # Post-attention RmsNorm
54
+ xn, _ = torch_npu.npu_rms_norm(x_in, weights['model.layers.0.post_attention_layernorm.weight'], epsilon=eps)
55
+ xn_flat = xn.view(S, D) # flatten batch
56
+
57
+ # Router: logits [S, E]
58
+ W_router = weights['model.layers.0.mlp.gate.weight'] # [E, D]
59
+ logits = xn_flat @ W_router.t() # [S, E]
60
+
61
+ # Top-k softmax
62
+ topk_logits, topk_idx = logits.topk(TK, dim=-1) # both [S, TK]
63
+ topk_weights = torch.softmax(topk_logits.float(), dim=-1) # [S, TK] F32
64
+ if norm_topk:
65
+ topk_weights = topk_weights / (topk_weights.sum(dim=-1, keepdim=True) + 1e-20)
66
+ topk_weights = topk_weights.to(torch.bfloat16)
67
+ topk_idx = topk_idx.to(torch.int32)
68
+
69
+ print(f"topk_idx[0]: {topk_idx[0].cpu().tolist()}")
70
+ print(f"topk_weights[0]: {topk_weights[0].cpu().float().tolist()}")
71
+
72
+ # MoE forward — loop over tokens (simple reference, not optimized)
73
+ out_flat = torch.zeros(S, D, dtype=torch.bfloat16, device='npu')
74
+ for s in range(S):
75
+ token = xn_flat[s] # [D]
76
+ acc = torch.zeros(D, dtype=torch.bfloat16, device='npu')
77
+ for k in range(TK):
78
+ e = int(topk_idx[s, k].item())
79
+ w = topk_weights[s, k]
80
+ Wg = weights[f'model.layers.0.mlp.experts.{e}.gate_proj.weight'] # [I, D]
81
+ Wu = weights[f'model.layers.0.mlp.experts.{e}.up_proj.weight'] # [I, D]
82
+ Wd = weights[f'model.layers.0.mlp.experts.{e}.down_proj.weight'] # [D, I]
83
+ gate = token @ Wg.t() # [I]
84
+ up = token @ Wu.t()
85
+ act = torch.nn.functional.silu(gate) * up
86
+ down = act @ Wd.t() # [D]
87
+ acc = acc + w * down
88
+ out_flat[s] = acc
89
+
90
+ moe_out = out_flat.view(1, S, D)
91
+ final_out = residual + moe_out
92
+ print(f"final_out[0,0,:4] = {final_out[0,0,:4].float().cpu().tolist()}")
93
+
94
+ # Dump
95
+ def dump(name, t):
96
+ p = os.path.join(OUT_DIR, name + '.bin')
97
+ a = t.contiguous().cpu().view(torch.int16).numpy().astype('int16')
98
+ open(p, 'wb').write(a.tobytes())
99
+
100
+ dump('x_in', x_in)
101
+ dump('final_out', final_out)
102
+ dump('moe_out', moe_out)
103
+ dump('router', W_router)
104
+ dump('xn', xn)
105
+ dump('topk_w', topk_weights) # [S, TK] BF16 (normalized)
106
+ dump('out_flat', out_flat) # [S, D] BF16 — moe contrib before residual
107
+
108
+ # expert_idx as int32 dump (raw bytes)
109
+ topk_idx_bytes = topk_idx.contiguous().cpu().numpy().astype('int32').tobytes()
110
+ open(os.path.join(OUT_DIR, 'topk_idx.bin'), 'wb').write(topk_idx_bytes)
111
+
112
+ with open(os.path.join(OUT_DIR, 'shape.txt'), 'w') as f:
113
+ f.write(f"S={S}\nD={D}\nI={I}\nE={E}\nTK={TK}\n")
114
+
115
+ print(f"\nDumps in {OUT_DIR}")
scripts/gen_rms_norm_reference.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Generate a RmsNorm reference using PyTorch."""
3
+ import os, struct
4
+ import numpy as np
5
+ import torch
6
+ import torch_npu
7
+
8
+ torch.npu.set_device(0)
9
+ torch.manual_seed(123)
10
+
11
+ N, D = 5, 4096 # 5 tokens, Qwen3 hidden_size
12
+ eps = 1e-6
13
+
14
+ x = torch.randn(N, D, dtype=torch.bfloat16).npu()
15
+ gamma = torch.randn(D, dtype=torch.bfloat16).npu() * 0.1 + 1.0
16
+
17
+ # Use torch_npu's npu_rms_norm if available, else do it manually
18
+ y_ref, _ = torch_npu.npu_rms_norm(x, gamma, epsilon=eps)
19
+
20
+ out_dir = 'tests/rms_norm_data'
21
+ os.makedirs(out_dir, exist_ok=True)
22
+
23
+ def dump_bf16(name, t):
24
+ path = os.path.join(out_dir, name + '.bin')
25
+ a = t.contiguous().cpu().view(torch.int16).numpy().astype('int16')
26
+ with open(path, 'wb') as f:
27
+ f.write(a.tobytes())
28
+ return path
29
+
30
+ dump_bf16('x', x)
31
+ dump_bf16('gamma', gamma)
32
+ dump_bf16('y_ref', y_ref)
33
+
34
+ with open(os.path.join(out_dir, 'shape.txt'), 'w') as f:
35
+ f.write(f"N={N}\nD={D}\neps={eps}\n")
36
+
37
+ print(f"x shape: {x.shape}, gamma: {gamma.shape}, y_ref: {y_ref.shape}")
38
+ print("y_ref[0, :8]:", y_ref[0, :8].float().cpu().tolist())
39
+ print("saved in", out_dir)
scripts/regen_rope_reference.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Re-generate RoPE reference using explicit HF formula (not torch_npu.npu_apply_rotary_pos_emb)."""
3
+ import os, math, torch, torch_npu
4
+ torch.npu.set_device(0)
5
+ torch.manual_seed(42)
6
+
7
+ S = 5; Hq = 64; Hkv = 4; Dh = 128
8
+ theta = 5e6
9
+ data = 'tests/attn_data'
10
+
11
+ def load_bf16(name, shape):
12
+ raw = open(os.path.join(data, name + '.bin'), 'rb').read()
13
+ a = torch.frombuffer(bytearray(raw), dtype=torch.int16).view(*shape).view(torch.bfloat16)
14
+ return a.npu()
15
+
16
+ q = load_bf16('q_normed', [1, S, Hq, Dh])
17
+ k = load_bf16('k_normed', [1, S, Hkv, Dh])
18
+
19
+ # Compute cos/sin identical to HF (rope_theta=5e6, 0..S positions)
20
+ inv_freq = 1.0 / (theta ** (torch.arange(0, Dh, 2, dtype=torch.float32).npu() / Dh))
21
+ pos = torch.arange(S, device='npu').float().unsqueeze(-1)
22
+ freqs = pos * inv_freq
23
+ emb = torch.cat([freqs, freqs], dim=-1) # [S, Dh]
24
+ cos = emb.cos().to(torch.bfloat16) # [S, Dh]
25
+ sin = emb.sin().to(torch.bfloat16)
26
+
27
+ # HF (Qwen3) style RoPE: q_rot = q * cos + rotate_half(q) * sin
28
+ def rotate_half(x):
29
+ h = x.shape[-1] // 2
30
+ x1 = x[..., :h]
31
+ x2 = x[..., h:]
32
+ return torch.cat([-x2, x1], dim=-1)
33
+
34
+ # Broadcast cos/sin from [S, Dh] to [1, S, 1, Dh]
35
+ cos_b = cos.unsqueeze(0).unsqueeze(2)
36
+ sin_b = sin.unsqueeze(0).unsqueeze(2)
37
+
38
+ q_roped_hf = q * cos_b + rotate_half(q) * sin_b
39
+ k_roped_hf = k * cos_b + rotate_half(k) * sin_b
40
+
41
+ print("HF-style q_roped[0,0,:4]:", q_roped_hf[0,0,0,:4].float().cpu().tolist())
42
+ print("cos[0,:4]:", cos[0,:4].float().cpu().tolist())
43
+ print("sin[0,:4]:", sin[0,:4].float().cpu().tolist())
44
+ print("cos[1,:4]:", cos[1,:4].float().cpu().tolist())
45
+
46
+ # Compare with existing q_roped (from torch_npu.npu_apply_rotary_pos_emb)
47
+ old_q_roped = load_bf16('q_roped', [1, S, Hq, Dh])
48
+ diff = (q_roped_hf - old_q_roped).float().abs().max().item()
49
+ print(f"\nDiff between HF formula and npu_apply: max={diff:.4f}")
50
+
51
+ # Save HF version as ground truth
52
+ def dump(name, t):
53
+ p = os.path.join(data, name + '.bin')
54
+ a = t.contiguous().cpu().view(torch.int16).numpy().astype('int16')
55
+ open(p, 'wb').write(a.tobytes())
56
+ dump('q_roped', q_roped_hf)
57
+ dump('k_roped', k_roped_hf)
58
+ # Overwrite cos, sin to [1, S, Dh] layout
59
+ dump('cos', cos.unsqueeze(0)) # [1, S, Dh]
60
+ dump('sin', sin.unsqueeze(0))
61
+
62
+ print("\nOverwrote q_roped, k_roped, cos, sin with HF-formula ground truth.")
scripts/tp_launch.sh ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+ # tp_launch.sh — launcher for TP>1 multi-process qwen3-moe-aclnn.
3
+ #
4
+ # Usage: ./tp_launch.sh <tp_size> <bin> [args...]
5
+ # e.g. ./tp_launch.sh 16 ./build/qwen3-moe-aclnn --model-dir ... --prompt "..." --n-predict 20
6
+ #
7
+ # Each rank runs as a separate process with:
8
+ # ASCEND_RT_VISIBLE_DEVICES=<rank>
9
+ # TP_RANK=<rank> TP_SIZE=<tp_size>
10
+ # HCCL_WHITELIST_DISABLE=1
11
+ # rank 0 creates /tmp/hccl_root_info.bin; other ranks wait for it.
12
+ set -euo pipefail
13
+
14
+ TP_SIZE="${1:?tp_size required}"; shift
15
+ BIN="${1:?binary required}"; shift
16
+
17
+ # Clean any stale HCCL coordination file
18
+ rm -f /tmp/hccl_root_info.bin
19
+
20
+ export HCCL_WHITELIST_DISABLE=1
21
+ # Benchmark-tuned defaults (bench_hccl_adv.sh 2026-04-21):
22
+ # ring:200 + OP_EXPANSION_MODE=AIV + OP_BASE_FFTS_MODE_ENABLE=1 → ~18.8 t/s median
23
+ # vs baseline (auto) ~12 t/s. +54% from HCCL env knobs alone.
24
+ export HCCL_ALGO="${HCCL_ALGO:-level0:ring}"
25
+ export HCCL_BUFFSIZE="${HCCL_BUFFSIZE:-200}"
26
+ export HCCL_OP_EXPANSION_MODE="${HCCL_OP_EXPANSION_MODE:-AIV}"
27
+ export HCCL_OP_BASE_FFTS_MODE_ENABLE="${HCCL_OP_BASE_FFTS_MODE_ENABLE:-1}"
28
+ # TASK_QUEUE_ENABLE=2: aggressive async task queueing (marginal gain on top of AIV+FFTS)
29
+ export TASK_QUEUE_ENABLE="${TASK_QUEUE_ENABLE:-2}"
30
+
31
+ # Launch ranks 1..N-1 in background with stdin/stdout redirected to /dev/null / logfile.
32
+ # Launch rank 0 LAST in foreground, inheriting the terminal stdin/stdout — so --interactive works.
33
+ pids=()
34
+ for rank in $(seq 1 $((TP_SIZE - 1))); do
35
+ logfile="/tmp/tp_rank_${rank}.log"
36
+ env ASCEND_RT_VISIBLE_DEVICES=${rank} \
37
+ TP_RANK=${rank} \
38
+ TP_SIZE=${TP_SIZE} \
39
+ "${BIN}" "$@" < /dev/null > "${logfile}" 2>&1 &
40
+ pids+=($!)
41
+ echo "[tp_launch] rank ${rank} pid=$! log=${logfile}"
42
+ done
43
+
44
+ # Give ranks 1..N-1 a moment to reach HcclCommInitRootInfo's file-wait before rank 0 writes it.
45
+ sleep 1
46
+
47
+ # Rank 0 in foreground — terminal stdin/stdout passthrough for REPL.
48
+ env ASCEND_RT_VISIBLE_DEVICES=0 \
49
+ TP_RANK=0 \
50
+ TP_SIZE=${TP_SIZE} \
51
+ "${BIN}" "$@"
52
+ ec=$?
53
+
54
+ # Wait for background ranks to finish (rank 0 exit signals end-of-work, but they may take a bit).
55
+ for i in "${!pids[@]}"; do
56
+ wait "${pids[$i]}" || true
57
+ done
58
+ exit $ec
src/device_weights.cpp ADDED
@@ -0,0 +1,221 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include "device_weights.h"
2
+ #include "aclnn_ops.h"
3
+ #include <cstdio>
4
+ #include <cstring>
5
+ #include <vector>
6
+
7
+ bool DeviceWeightsLoader::load_tensor_full_(const std::string& name, DeviceBuffer& buf) {
8
+ const auto* m = st_.get(name);
9
+ if (!m) { fprintf(stderr, "load_tensor_full_: missing %s\n", name.c_str()); return false; }
10
+ const void* host = st_.data_ptr(*m);
11
+ if (!host) { fprintf(stderr, "load_tensor_full_: null host ptr %s\n", name.c_str()); return false; }
12
+ buf.alloc(m->nbytes);
13
+ ACL_CHECK(aclrtMemcpy(buf.get(), m->nbytes, host, m->nbytes, ACL_MEMCPY_HOST_TO_DEVICE));
14
+ return true;
15
+ }
16
+
17
+ bool DeviceWeightsLoader::load_tensor_row_slice_(const std::string& name,
18
+ int64_t row_lo, int64_t row_hi,
19
+ DeviceBuffer& buf) {
20
+ const auto* m = st_.get(name);
21
+ if (!m) { fprintf(stderr, "load_tensor_row_slice_: missing %s\n", name.c_str()); return false; }
22
+ if (m->shape.empty()) { fprintf(stderr, "%s: empty shape\n", name.c_str()); return false; }
23
+ int64_t D0 = m->shape[0];
24
+ if (row_hi > D0 || row_lo < 0 || row_hi <= row_lo) {
25
+ fprintf(stderr, "load_tensor_row_slice_: %s bad range [%ld,%ld) vs D0=%ld\n",
26
+ name.c_str(), row_lo, row_hi, D0);
27
+ return false;
28
+ }
29
+ size_t elem = sdtype_size(m->dtype);
30
+ size_t inner = 1;
31
+ for (size_t i = 1; i < m->shape.size(); i++) inner *= m->shape[i];
32
+ size_t row_bytes = inner * elem;
33
+ size_t slice_bytes = (row_hi - row_lo) * row_bytes;
34
+
35
+ const auto* host = (const char*)st_.data_ptr(*m);
36
+ buf.alloc(slice_bytes);
37
+ ACL_CHECK(aclrtMemcpy(buf.get(), slice_bytes,
38
+ host + row_lo * row_bytes, slice_bytes,
39
+ ACL_MEMCPY_HOST_TO_DEVICE));
40
+ return true;
41
+ }
42
+
43
+ bool DeviceWeightsLoader::load_tensor_col_slice_(const std::string& name,
44
+ int64_t col_lo, int64_t col_hi,
45
+ DeviceBuffer& buf) {
46
+ const auto* m = st_.get(name);
47
+ if (!m || m->shape.size() < 2) {
48
+ fprintf(stderr, "load_tensor_col_slice_: bad shape %s\n", name.c_str()); return false;
49
+ }
50
+ int64_t D0 = m->shape[0];
51
+ int64_t D1 = m->shape[1];
52
+ if (col_hi > D1 || col_lo < 0 || col_hi <= col_lo) {
53
+ fprintf(stderr, "load_tensor_col_slice_: bad range %ld-%ld D1=%ld\n",
54
+ col_lo, col_hi, D1); return false;
55
+ }
56
+ size_t elem = sdtype_size(m->dtype);
57
+ int64_t new_cols = col_hi - col_lo;
58
+ size_t slice_bytes = D0 * new_cols * elem;
59
+ buf.alloc(slice_bytes);
60
+
61
+ // Need to copy row-by-row since source has stride D1 but dest has stride new_cols.
62
+ const auto* host = (const char*)st_.data_ptr(*m);
63
+ std::vector<char> staging(slice_bytes);
64
+ size_t src_row = D1 * elem;
65
+ size_t dst_row = new_cols * elem;
66
+ size_t col_off = col_lo * elem;
67
+ for (int64_t r = 0; r < D0; r++) {
68
+ std::memcpy(staging.data() + r * dst_row, host + r * src_row + col_off, dst_row);
69
+ }
70
+ ACL_CHECK(aclrtMemcpy(buf.get(), slice_bytes, staging.data(), slice_bytes,
71
+ ACL_MEMCPY_HOST_TO_DEVICE));
72
+ return true;
73
+ }
74
+
75
+ bool DeviceWeightsLoader::load_shared(SharedWeights& out) {
76
+ if (!load_tensor_full_("model.embed_tokens.weight", out.embed_tokens)) return false;
77
+ if (!load_tensor_full_("lm_head.weight", out.lm_head)) return false;
78
+ if (!load_tensor_full_("model.norm.weight", out.final_norm)) return false;
79
+ return true;
80
+ }
81
+
82
+ bool DeviceWeightsLoader::load_moe(int L, aclrtStream stream, LayerMoEWeights& out) {
83
+ const int64_t E = cfg_.num_experts;
84
+ const int64_t D = cfg_.hidden_size;
85
+ const int64_t I_full = cfg_.moe_intermediate_size;
86
+ const int64_t I_rank = cfg_.i_per_rank;
87
+ const size_t elem = 2; // BF16
88
+
89
+ auto base = "model.layers." + std::to_string(L);
90
+
91
+ // 1. Router [E, D] — small, fully replicated
92
+ if (!load_tensor_full_(base + ".mlp.gate.weight", out.router)) return false;
93
+
94
+ // 2. MoE expert weights: need to stack 128 experts + TP slice + permute
95
+ // HF gate/up: each expert [I_full, D] → TP slice rows to [I_rank, D]
96
+ // HF down: each expert [D, I_full] → TP slice cols to [D, I_rank]
97
+
98
+ auto load_and_stack = [&](const std::string& proj_name,
99
+ bool is_down, DeviceBuffer& final_buf) -> bool {
100
+ // HF shape for gate/up: [I_full, D]; for down: [D, I_full]
101
+ // After TP slice: gate/up rows [I_rank, D]; down cols [D, I_rank]
102
+ // Stacked:
103
+ // gate/up: [E, I_rank, D] → permute to [E, D, I_rank]
104
+ // down: [E, D, I_rank] → permute to [E, I_rank, D]
105
+
106
+ int64_t K_in, N_out;
107
+ bool row_slice;
108
+ if (!is_down) {
109
+ K_in = I_rank; // HF first dim after row-slice
110
+ N_out = D;
111
+ row_slice = true;
112
+ } else {
113
+ K_in = D;
114
+ N_out = I_rank;
115
+ row_slice = false; // col slice
116
+ }
117
+
118
+ // Stage: stacked HF-layout [E, K_in, N_out] on device (before permute)
119
+ size_t elem_stack = K_in * N_out * elem;
120
+ DeviceBuffer stacked_hf(E * elem_stack);
121
+
122
+ // For each expert, load + TP slice + memcpy to stacked_hf[e]
123
+ // We use the existing row_slice/col_slice helpers on a per-expert basis.
124
+ DeviceBuffer tmp;
125
+ for (int e = 0; e < E; e++) {
126
+ std::string name = base + ".mlp.experts." + std::to_string(e) + "." + proj_name + ".weight";
127
+ if (row_slice) {
128
+ int64_t lo = cfg_.tp_rank * I_rank;
129
+ int64_t hi = lo + I_rank;
130
+ if (!load_tensor_row_slice_(name, lo, hi, tmp)) return false;
131
+ } else {
132
+ int64_t lo = cfg_.tp_rank * I_rank;
133
+ int64_t hi = lo + I_rank;
134
+ if (!load_tensor_col_slice_(name, lo, hi, tmp)) return false;
135
+ }
136
+ if (tmp.size != elem_stack) {
137
+ fprintf(stderr, "load_moe: expert %d %s slice size %zu != expected %zu\n",
138
+ e, name.c_str(), tmp.size, elem_stack);
139
+ return false;
140
+ }
141
+ // Synchronous D2D: tmp is about to be reallocated in the next iteration,
142
+ // so we cannot enqueue an async copy that would still reference it.
143
+ ACL_CHECK(aclrtMemcpy(
144
+ (char*)stacked_hf.get() + e * elem_stack, elem_stack,
145
+ tmp.get(), elem_stack,
146
+ ACL_MEMCPY_DEVICE_TO_DEVICE));
147
+ }
148
+
149
+ // Now permute stacked_hf [E, K_in, N_out] → final [E, N_out, K_in] row-major
150
+ // (swap last two dims)
151
+ final_buf.alloc(E * elem_stack);
152
+ const int64_t d0 = E, d1 = K_in, d2 = N_out;
153
+ // View stacked_hf with permuted strides pointing to same data:
154
+ // logical shape [E, N_out, K_in], strides [K_in*N_out, 1, N_out]
155
+ // (since physical is [E, K_in, N_out] row-major with strides [K_in*N_out, N_out, 1])
156
+ auto t_src = make_acl_tensor(stacked_hf.get(), ACL_BF16,
157
+ {d0, d2, d1}, // [E, N_out, K_in]
158
+ {d1 * d2, 1, d2});
159
+ auto t_dst = make_contig_tensor(final_buf.get(), ACL_BF16, {d0, d2, d1});
160
+ inplace_copy(stream, t_dst.get(), t_src.get());
161
+ // Must sync before stacked_hf goes out of scope — the inplace_copy is async and
162
+ // reads from stacked_hf's memory. If we return without syncing, DeviceBuffer's
163
+ // destructor frees stacked_hf while the permute kernel is still running, producing
164
+ // garbage in final_buf.
165
+ ACL_CHECK(aclrtSynchronizeStream(stream));
166
+ return true;
167
+ };
168
+
169
+ if (!load_and_stack("gate_proj", false, out.gate_exps)) return false;
170
+ if (!load_and_stack("up_proj", false, out.up_exps)) return false;
171
+ if (!load_and_stack("down_proj", true, out.down_exps)) return false;
172
+
173
+ return true;
174
+ }
175
+
176
+ bool DeviceWeightsLoader::load_attention(int L, LayerAttnWeights& out) {
177
+ auto base = "model.layers." + std::to_string(L);
178
+
179
+ if (!load_tensor_full_(base + ".input_layernorm.weight", out.input_layernorm)) return false;
180
+ if (!load_tensor_full_(base + ".post_attention_layernorm.weight", out.post_attention_layernorm)) return false;
181
+ if (!load_tensor_full_(base + ".self_attn.q_norm.weight", out.q_norm)) return false;
182
+ if (!load_tensor_full_(base + ".self_attn.k_norm.weight", out.k_norm)) return false;
183
+
184
+ const int64_t head_dim = cfg_.head_dim;
185
+ const int64_t q_full = cfg_.num_attention_heads * head_dim; // 64 * 128 = 8192
186
+
187
+ // q_proj: [q_full, D], shard rows by head. Each rank gets n_heads_per_rank heads.
188
+ int64_t q_rows_per_rank = cfg_.n_heads_per_rank * head_dim;
189
+ int64_t q_row_lo = cfg_.tp_rank * q_rows_per_rank;
190
+ int64_t q_row_hi = q_row_lo + q_rows_per_rank;
191
+ if (!load_tensor_row_slice_(base + ".self_attn.q_proj.weight",
192
+ q_row_lo, q_row_hi, out.q_proj)) return false;
193
+
194
+ // k_proj, v_proj: HF shape [num_kv * head_dim, D].
195
+ // Case A (tp <= n_kv): split rows across ranks, each rank gets n_kv/tp KV heads.
196
+ // Case B (tp > n_kv): each rank gets exactly ONE KV head; group of (tp/n_kv) ranks share it.
197
+ // kv_head_idx = tp_rank / (tp_size / n_kv)
198
+ if (cfg_.tp_size <= cfg_.num_key_value_heads) {
199
+ int64_t kv_rows_per_rank = cfg_.n_kv_heads_per_rank * head_dim;
200
+ int64_t kv_row_lo = cfg_.tp_rank * kv_rows_per_rank;
201
+ int64_t kv_row_hi = kv_row_lo + kv_rows_per_rank;
202
+ if (!load_tensor_row_slice_(base + ".self_attn.k_proj.weight", kv_row_lo, kv_row_hi, out.k_proj)) return false;
203
+ if (!load_tensor_row_slice_(base + ".self_attn.v_proj.weight", kv_row_lo, kv_row_hi, out.v_proj)) return false;
204
+ } else {
205
+ // GQA replicated-group mode: 1 KV head per rank, selected by group.
206
+ int64_t ranks_per_kv = cfg_.tp_size / cfg_.num_key_value_heads;
207
+ int64_t kv_head_idx = cfg_.tp_rank / ranks_per_kv;
208
+ int64_t kv_row_lo = kv_head_idx * head_dim;
209
+ int64_t kv_row_hi = kv_row_lo + head_dim;
210
+ if (!load_tensor_row_slice_(base + ".self_attn.k_proj.weight", kv_row_lo, kv_row_hi, out.k_proj)) return false;
211
+ if (!load_tensor_row_slice_(base + ".self_attn.v_proj.weight", kv_row_lo, kv_row_hi, out.v_proj)) return false;
212
+ }
213
+
214
+ // o_proj: [D, q_full], row-parallel → shard cols (input dim) by head.
215
+ int64_t o_col_lo = q_row_lo; // same slicing as q rows
216
+ int64_t o_col_hi = q_row_hi;
217
+ if (!load_tensor_col_slice_(base + ".self_attn.o_proj.weight",
218
+ o_col_lo, o_col_hi, out.o_proj)) return false;
219
+
220
+ return true;
221
+ }
src/main_cli.cpp ADDED
@@ -0,0 +1,816 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // main_cli.cpp — qwen3-moe-aclnn entry point.
2
+ //
3
+ // Usage:
4
+ // qwen3-moe-aclnn --model-dir <path> --prompt "<text>" --n-predict <N>
5
+ // [--tp-size 1|16] [--vocab <path>] [--max-seq N] [--num-layers N]
6
+ // [--chat] [--temperature 0.7] [--top-k 20] [--top-p 0.8] [--seed N]
7
+ // [--no-stream]
8
+ //
9
+ // At TP>1 each rank is a separate process (env TP_RANK=<i>, TP_SIZE=<n>) launched by
10
+ // scripts/tp_launch.sh. Only rank 0 prints text output.
11
+ #include "runner.h"
12
+ #include "tokenizer.h"
13
+
14
+ // Escape hatch for HCCL broadcast from within CLI (defined in runner.cpp)
15
+ HcclCtx* runner_hccl_ctx_shim(Runner& r);
16
+
17
+ #include <algorithm>
18
+ #include <chrono>
19
+ #include <cmath>
20
+ #include <cstdio>
21
+ #include <cstdlib>
22
+ #include <cstring>
23
+ #include <iostream>
24
+ #include <random>
25
+ #include <string>
26
+ #include <vector>
27
+
28
+ static float bf16_to_float(uint16_t x) {
29
+ uint32_t u = (uint32_t)x << 16; float f; std::memcpy(&f, &u, 4); return f;
30
+ }
31
+
32
+ // Truncate a string to the last complete UTF-8 character boundary. If the last 1-3 bytes
33
+ // form an incomplete multi-byte sequence (e.g., assistant response cut mid-codepoint at
34
+ // n_predict limit), drop them so the JSON encoder downstream sees only valid UTF-8.
35
+ static std::string utf8_trim_incomplete(const std::string& s) {
36
+ if (s.empty()) return s;
37
+ size_t n = s.size();
38
+ // Walk back up to 4 bytes looking for the start of a UTF-8 sequence.
39
+ for (size_t back = 0; back < 4 && back < n; back++) {
40
+ size_t i = n - 1 - back;
41
+ unsigned char c = (unsigned char)s[i];
42
+ if ((c & 0x80) == 0) { return s; } // ASCII: already complete
43
+ if ((c & 0xC0) == 0x80) { continue; } // continuation byte: keep going
44
+ // Start byte: 110xxxxx (2-byte), 1110xxxx (3-byte), 11110xxx (4-byte)
45
+ size_t need = 0;
46
+ if ((c & 0xE0) == 0xC0) need = 2;
47
+ else if ((c & 0xF0) == 0xE0) need = 3;
48
+ else if ((c & 0xF8) == 0xF0) need = 4;
49
+ else return s.substr(0, i); // invalid start — drop
50
+ size_t have = back + 1;
51
+ return (have >= need) ? s : s.substr(0, i); // trim incomplete trailing sequence
52
+ }
53
+ // Should not reach here; return as-is.
54
+ return s;
55
+ }
56
+
57
+ struct Args {
58
+ std::string model_dir;
59
+ std::string prompt = "The capital of France is";
60
+ std::string vocab_path = "tokenizer_data/vocab.bin";
61
+ int n_predict = 100;
62
+ int tp_size = 1;
63
+ int tp_rank = 0;
64
+ int num_layers = 0; // 0 = auto
65
+ int max_seq = 512;
66
+ int device_id = 0;
67
+ bool chat_template = false;
68
+ bool stream = true;
69
+ bool interactive = false;
70
+ bool reset_each_turn = false; // if true, REPL clears KV cache between turns (stateless)
71
+ std::string system_prompt; // optional system role for chat mode
72
+ std::string prompt_file; // read prompt from file (avoids shell escaping)
73
+ bool pld_enabled = false; // prompt lookup decoding
74
+ int pld_k = 10; // bench_pld_k.sh: K=10 median 105 t/s (3/3 runs 100+), K=8 was 35
75
+ int pld_ngram = 1; // n-gram match size — 1 with multi-level fallback best
76
+ bool pld_adaptive = false; // fixed K=10 is simpler and mean-optimal; adaptive --pld-adaptive
77
+ int pld_min_hist = 20; // skip PLD until history >= this (avoid early-token false matches)
78
+ // PLD degeneration guard (on by default): prevents PLD from amplifying repetition loops.
79
+ bool pld_guard = true; // --pld-no-guard disables
80
+ int pld_guard_distinct = 3; // reject draft if distinct tokens < this (≥K/3 heuristic)
81
+ int pld_guard_tail = 6; // reject if draft[0] matches all last N hist tokens
82
+ int pld_loop_warn = 8; // warn once when N consecutive identical tokens emitted
83
+ float temperature = 0.0f; // 0 = greedy
84
+ int top_k = 0; // 0 = disabled
85
+ float top_p = 1.0f; // 1.0 = disabled
86
+ uint64_t seed = 0; // 0 = use time
87
+ // Qwen3 EOS tokens (from generation_config.json)
88
+ std::vector<int> eos_ids = {151645, 151643};
89
+ };
90
+
91
+ static bool parse_args(int argc, char** argv, Args& a) {
92
+ for (int i = 1; i < argc; i++) {
93
+ std::string s = argv[i];
94
+ auto next = [&](const char* f)->const char* {
95
+ if (i + 1 >= argc) { fprintf(stderr, "missing value for %s\n", f); return nullptr; }
96
+ return argv[++i];
97
+ };
98
+ if (s == "--model-dir") { auto v = next(s.c_str()); if (!v) return false; a.model_dir = v; }
99
+ else if (s == "--prompt") { auto v = next(s.c_str()); if (!v) return false; a.prompt = v; }
100
+ else if (s == "--vocab") { auto v = next(s.c_str()); if (!v) return false; a.vocab_path = v; }
101
+ else if (s == "--n-predict") { auto v = next(s.c_str()); if (!v) return false; a.n_predict = std::atoi(v); }
102
+ else if (s == "--tp-size") { auto v = next(s.c_str()); if (!v) return false; a.tp_size = std::atoi(v); }
103
+ else if (s == "--num-layers") { auto v = next(s.c_str()); if (!v) return false; a.num_layers = std::atoi(v); }
104
+ else if (s == "--max-seq") { auto v = next(s.c_str()); if (!v) return false; a.max_seq = std::atoi(v); }
105
+ else if (s == "--device") { auto v = next(s.c_str()); if (!v) return false; a.device_id = std::atoi(v); }
106
+ else if (s == "--temperature") { auto v = next(s.c_str()); if (!v) return false; a.temperature = (float)std::atof(v); }
107
+ else if (s == "--top-k") { auto v = next(s.c_str()); if (!v) return false; a.top_k = std::atoi(v); }
108
+ else if (s == "--top-p") { auto v = next(s.c_str()); if (!v) return false; a.top_p = (float)std::atof(v); }
109
+ else if (s == "--seed") { auto v = next(s.c_str()); if (!v) return false; a.seed = (uint64_t)std::atoll(v); }
110
+ else if (s == "--chat") { a.chat_template = true; }
111
+ else if (s == "--no-stream") { a.stream = false; }
112
+ else if (s == "--interactive" || s == "-i") { a.interactive = true; }
113
+ else if (s == "--reset") { a.reset_each_turn = true; }
114
+ else if (s == "--system") { auto v = next(s.c_str()); if (!v) return false; a.system_prompt = v; }
115
+ else if (s == "--prompt-file") { auto v = next(s.c_str()); if (!v) return false; a.prompt_file = v; }
116
+ else if (s == "--pld") { a.pld_enabled = true; }
117
+ else if (s == "--pld-k") { auto v = next(s.c_str()); if (!v) return false; a.pld_k = std::atoi(v); }
118
+ else if (s == "--pld-ngram") { auto v = next(s.c_str()); if (!v) return false; a.pld_ngram = std::atoi(v); }
119
+ else if (s == "--pld-adaptive"){ a.pld_adaptive = true; }
120
+ else if (s == "--pld-fixed-k") { a.pld_adaptive = false; } // opt out of adaptive
121
+ else if (s == "--pld-min-hist"){ auto v = next(s.c_str()); if (!v) return false; a.pld_min_hist = std::atoi(v); }
122
+ else if (s == "--pld-no-guard"){ a.pld_guard = false; }
123
+ else if (s == "--pld-guard-distinct"){ auto v = next(s.c_str()); if (!v) return false; a.pld_guard_distinct = std::atoi(v); }
124
+ else if (s == "--pld-guard-tail"){ auto v = next(s.c_str()); if (!v) return false; a.pld_guard_tail = std::atoi(v); }
125
+ else if (s == "--pld-loop-warn"){ auto v = next(s.c_str()); if (!v) return false; a.pld_loop_warn = std::atoi(v); }
126
+ else if (s == "--help" || s == "-h") {
127
+ printf("Usage: %s --model-dir <path> [options]\n", argv[0]);
128
+ printf(" --prompt \"text\" prompt text (default: \"%s\")\n", a.prompt.c_str());
129
+ printf(" --prompt-file FILE read prompt from file (overrides --prompt)\n");
130
+ printf(" --n-predict N max tokens to generate (default: %d)\n", a.n_predict);
131
+ printf(" --tp-size N tensor parallelism (default: 1; or TP_SIZE env)\n");
132
+ printf(" --num-layers N limit layers, testing only (default: all)\n");
133
+ printf(" --max-seq N KV cache + context cap (default: %d)\n", a.max_seq);
134
+ printf(" --chat apply Qwen3 chat template\n");
135
+ printf(" --system \"text\" system role for chat\n");
136
+ printf(" --temperature F 0 = greedy; typical 0.7\n");
137
+ printf(" --top-k N 0 = disabled\n");
138
+ printf(" --top-p F 1.0 = disabled; typical 0.8\n");
139
+ printf(" --seed N 0 = time-based (default)\n");
140
+ printf(" --no-stream batch-print final text\n");
141
+ printf(" -i, --interactive REPL (multi-turn memory when --chat)\n");
142
+ printf(" --reset force stateless REPL (reset each turn)\n");
143
+ printf(" --pld enable Prompt Lookup Decoding (greedy only)\n");
144
+ printf(" --pld-k N draft window size (default: 4)\n");
145
+ printf(" --pld-ngram N match n-gram size (default: 2; multi-level fallback)\n");
146
+ printf(" --pld-adaptive adjust K based on recent accept rate\n");
147
+ printf(" --pld-min-hist N skip PLD until history >= N tokens (default: 20)\n");
148
+ printf(" --pld-no-guard disable degeneration guard (dangerous: can amplify loops)\n");
149
+ printf(" --pld-guard-distinct N reject draft with distinct tokens < N (default: 3)\n");
150
+ printf(" --pld-guard-tail N reject draft if draft[0] matches all last N hist (default: 6)\n");
151
+ printf(" --pld-loop-warn N warn once on N consecutive identical emitted tokens (default: 8)\n");
152
+ return false;
153
+ }
154
+ else { fprintf(stderr, "unknown arg: %s\n", s.c_str()); return false; }
155
+ }
156
+ if (a.model_dir.empty()) { fprintf(stderr, "--model-dir required\n"); return false; }
157
+ if (const char* r = std::getenv("TP_RANK")) a.tp_rank = std::atoi(r);
158
+ if (const char* s = std::getenv("TP_SIZE")) a.tp_size = std::atoi(s);
159
+ return true;
160
+ }
161
+
162
+ // Sample next token from logits. temperature=0 → greedy argmax. Otherwise top-k / top-p.
163
+ static int sample_token(const std::vector<uint16_t>& logits_bf16, int64_t V,
164
+ float temperature, int top_k, float top_p, std::mt19937& rng) {
165
+ if (temperature <= 0.0f) {
166
+ int best = 0;
167
+ float bv = bf16_to_float(logits_bf16[0]);
168
+ for (int64_t i = 1; i < V; i++) {
169
+ float v = bf16_to_float(logits_bf16[i]);
170
+ if (v > bv) { bv = v; best = (int)i; }
171
+ }
172
+ return best;
173
+ }
174
+
175
+ // Build (logit, id) list as float
176
+ std::vector<std::pair<float, int>> scored;
177
+ scored.reserve(V);
178
+ for (int64_t i = 0; i < V; i++) {
179
+ scored.emplace_back(bf16_to_float(logits_bf16[i]) / temperature, (int)i);
180
+ }
181
+
182
+ // Top-k: keep highest k entries (partial sort)
183
+ if (top_k > 0 && top_k < (int)scored.size()) {
184
+ std::nth_element(scored.begin(), scored.begin() + top_k, scored.end(),
185
+ [](const auto& a, const auto& b){ return a.first > b.first; });
186
+ scored.resize(top_k);
187
+ }
188
+
189
+ // Sort descending for top-p
190
+ std::sort(scored.begin(), scored.end(),
191
+ [](const auto& a, const auto& b){ return a.first > b.first; });
192
+
193
+ // Softmax (numerically stable)
194
+ float maxv = scored[0].first;
195
+ double sum = 0;
196
+ for (auto& p : scored) { p.first = std::exp(p.first - maxv); sum += p.first; }
197
+ for (auto& p : scored) p.first /= (float)sum;
198
+
199
+ // Top-p nucleus
200
+ if (top_p > 0.0f && top_p < 1.0f) {
201
+ double cum = 0;
202
+ size_t cutoff = scored.size();
203
+ for (size_t i = 0; i < scored.size(); i++) {
204
+ cum += scored[i].first;
205
+ if (cum >= top_p) { cutoff = i + 1; break; }
206
+ }
207
+ scored.resize(cutoff);
208
+ // re-normalize
209
+ double s = 0; for (auto& p : scored) s += p.first;
210
+ for (auto& p : scored) p.first /= (float)s;
211
+ }
212
+
213
+ // Sample
214
+ std::uniform_real_distribution<float> U(0.0f, 1.0f);
215
+ float r = U(rng), acc = 0.0f;
216
+ for (auto& p : scored) {
217
+ acc += p.first;
218
+ if (r <= acc) return p.second;
219
+ }
220
+ return scored.back().second;
221
+ }
222
+
223
+ // Broadcast a prompt's token_ids from rank 0 to all ranks. For TP>1 the non-master ranks need
224
+ // the tokens before prefill. We use HCCL broadcast: rank 0 provides the count, then the ids.
225
+ // Uses a pre-allocated device buffer (must be large enough for max_seq tokens).
226
+ static bool broadcast_token_ids(Runner& runner, std::vector<int32_t>& ids,
227
+ int64_t max_seq, bool is_master) {
228
+ const ModelConfig& cfg = runner.cfg();
229
+ if (cfg.tp_size <= 1) return true;
230
+
231
+ // Step 1: broadcast count (as int32 on device)
232
+ DeviceBuffer cnt_dev(4);
233
+ int32_t cnt = is_master ? (int32_t)ids.size() : 0;
234
+ ACL_CHECK(aclrtMemcpy(cnt_dev.get(), 4, &cnt, 4, ACL_MEMCPY_HOST_TO_DEVICE));
235
+ // Access Runner's HCCL context via stream (exposed) and rely on the fact that ctx.comm is owned.
236
+ // Since hccl_broadcast needs HcclCtx, we need access. Cheapest: friend access via a shim member.
237
+ // For now, Runner has a stream() accessor; HCCL ctx is private. We'll accept that and broadcast
238
+ // via a direct call on the comm — but ctx is hidden. Workaround: expose hccl_ctx() on Runner.
239
+ // ... (see Runner::hccl_ctx() accessor added below)
240
+ extern HcclCtx* runner_hccl_ctx_shim(Runner& r); // forward from runner.cpp
241
+ HcclCtx* ctx = runner_hccl_ctx_shim(runner);
242
+ if (!ctx) return false;
243
+
244
+ if (!hccl_broadcast(*ctx, cnt_dev.get(), 1, HCCL_DATA_TYPE_INT32, 0, runner.stream())) return false;
245
+ ACL_CHECK(aclrtSynchronizeStream(runner.stream()));
246
+ ACL_CHECK(aclrtMemcpy(&cnt, 4, cnt_dev.get(), 4, ACL_MEMCPY_DEVICE_TO_HOST));
247
+ if (cnt <= 0 || cnt > (int32_t)max_seq) {
248
+ fprintf(stderr, "[rank %d] broadcast: bad count %d\n", cfg.tp_rank, cnt);
249
+ return false;
250
+ }
251
+
252
+ // Step 2: broadcast the id buffer
253
+ DeviceBuffer ids_dev(cnt * 4);
254
+ if (is_master) {
255
+ ACL_CHECK(aclrtMemcpy(ids_dev.get(), cnt*4, ids.data(), cnt*4, ACL_MEMCPY_HOST_TO_DEVICE));
256
+ }
257
+ if (!hccl_broadcast(*ctx, ids_dev.get(), cnt, HCCL_DATA_TYPE_INT32, 0, runner.stream())) return false;
258
+ ACL_CHECK(aclrtSynchronizeStream(runner.stream()));
259
+ if (!is_master) {
260
+ ids.resize(cnt);
261
+ ACL_CHECK(aclrtMemcpy(ids.data(), cnt*4, ids_dev.get(), cnt*4, ACL_MEMCPY_DEVICE_TO_HOST));
262
+ }
263
+ return true;
264
+ }
265
+
266
+ // Run one generation turn. Assumes KV cache is reset. Returns perf summary.
267
+ struct TurnStats {
268
+ double prefill_ms = 0; double decode_ms = 0;
269
+ int n_prompt = 0; int decoded = 0; bool hit_eos = false;
270
+ };
271
+ static TurnStats run_turn(Runner& runner, Tokenizer& tokenizer, const Args& args,
272
+ const std::string& prompt, std::mt19937& rng, bool is_master) {
273
+ TurnStats st;
274
+
275
+ // --- Tokenize (on master; broadcast for TP>1) ---
276
+ std::vector<int32_t> input_ids;
277
+ if (is_master) {
278
+ auto raw = tokenizer.encode_via_python(args.model_dir, prompt, args.chat_template);
279
+ if (raw.empty()) return st;
280
+ input_ids.reserve(raw.size());
281
+ for (int v : raw) input_ids.push_back((int32_t)v);
282
+ }
283
+ if (args.tp_size > 1) {
284
+ if (!broadcast_token_ids(runner, input_ids, args.max_seq, is_master)) return st;
285
+ }
286
+ if (input_ids.empty()) return st;
287
+
288
+ const int64_t V = runner.cfg().vocab_size;
289
+ std::vector<uint16_t> logits_h(V);
290
+ auto load_logits = [&](DeviceBuffer& buf) {
291
+ ACL_CHECK(aclrtMemcpy(logits_h.data(), V*2, buf.get(), V*2, ACL_MEMCPY_DEVICE_TO_HOST));
292
+ };
293
+ auto is_eos = [&](int id) {
294
+ for (int e : args.eos_ids) if (id == e) return true;
295
+ return false;
296
+ };
297
+
298
+ // --- Prefill ---
299
+ st.n_prompt = (int)input_ids.size();
300
+ auto t0 = std::chrono::steady_clock::now();
301
+ DeviceBuffer logits;
302
+ if (!runner.prefill(input_ids.data(), (int64_t)input_ids.size(), logits)) return st;
303
+ auto t1 = std::chrono::steady_clock::now();
304
+ st.prefill_ms = std::chrono::duration<double, std::milli>(t1 - t0).count();
305
+
306
+ load_logits(logits);
307
+ int next_id = sample_token(logits_h, V, args.temperature, args.top_k, args.top_p, rng);
308
+
309
+ if (is_master && args.stream) {
310
+ if (!args.chat_template) printf("%s", prompt.c_str());
311
+ printf("%s", tokenizer.decode(next_id).c_str());
312
+ fflush(stdout);
313
+ }
314
+
315
+ std::vector<int> generated = { next_id };
316
+ st.hit_eos = is_eos(next_id);
317
+
318
+ // All tokens (prompt + generated) for PLD n-gram lookup. Non-master ranks still need to
319
+ // track consistent length for HCCL broadcast of draft proposals.
320
+ std::vector<int32_t> hist;
321
+ hist.reserve(input_ids.size() + args.n_predict + 16);
322
+ for (auto x : input_ids) hist.push_back(x);
323
+ hist.push_back(next_id);
324
+
325
+ // PLD n-gram lookup: search for suffix match ending at end-of-hist; return K tokens following match.
326
+ // Longer matches = more reliable drafts. Multi-level: try n, fall back to smaller n if no match.
327
+ auto lookup_one = [&](int ngram, int K) -> std::vector<int32_t> {
328
+ int hs = (int)hist.size();
329
+ if (hs < ngram + 1 || K <= 0) return {};
330
+ for (int start = hs - ngram - 1; start >= 0; start--) {
331
+ bool match = true;
332
+ for (int k = 0; k < ngram; k++) {
333
+ if (hist[start + k] != hist[hs - ngram + k]) { match = false; break; }
334
+ }
335
+ if (match) {
336
+ int after = start + ngram;
337
+ std::vector<int32_t> d;
338
+ for (int k = 0; k < K && after + k < hs; k++) {
339
+ d.push_back(hist[after + k]);
340
+ if (is_eos(hist[after + k])) break;
341
+ }
342
+ if (!d.empty()) return d;
343
+ }
344
+ }
345
+ return {};
346
+ };
347
+ // Multi-level: try configured n first, then n-1, then n-2 (down to 1).
348
+ auto lookup_draft = [&](int ngram, int K) -> std::vector<int32_t> {
349
+ for (int n = ngram; n >= 1; n--) {
350
+ auto d = lookup_one(n, K);
351
+ if (!d.empty()) return d;
352
+ }
353
+ return {};
354
+ };
355
+
356
+ // Degeneration guard: classify a draft as repetition-induced so we can fall back to single
357
+ // decode (and avoid PLD amplifying model's own repetition loop into a runaway "W W W …" mess).
358
+ // Returns nullptr if draft is OK, else a short reason string for stats.
359
+ auto draft_degenerate = [&](const std::vector<int32_t>& d) -> const char* {
360
+ if (!args.pld_guard || d.empty()) return nullptr;
361
+ // (1) distinct-token count: a draft of K tokens with < args.pld_guard_distinct distinct
362
+ // values means n-gram is echoing a loop. Only apply when draft is long enough.
363
+ if ((int)d.size() >= 3) {
364
+ int distinct = 0;
365
+ for (int i = 0; i < (int)d.size(); i++) {
366
+ bool seen = false;
367
+ for (int j = 0; j < i; j++) { if (d[j] == d[i]) { seen = true; break; } }
368
+ if (!seen) distinct++;
369
+ }
370
+ if (distinct < args.pld_guard_distinct) return "low-distinct";
371
+ }
372
+ // (2) tail echo: if the last N hist tokens are all equal to draft[0], the model is already
373
+ // in a short loop — accepting the draft will just confirm the loop at batch speed.
374
+ int tail_n = std::min(args.pld_guard_tail, (int)hist.size());
375
+ if (tail_n >= 3) {
376
+ int matches = 0;
377
+ for (int i = (int)hist.size() - tail_n; i < (int)hist.size(); i++) {
378
+ if (hist[i] == d[0]) matches++;
379
+ }
380
+ if (matches == tail_n) return "tail-echo";
381
+ }
382
+ return nullptr;
383
+ };
384
+
385
+ // --- Decode loop ---
386
+ auto t2 = std::chrono::steady_clock::now();
387
+ int pld_verifies = 0, pld_accepted = 0;
388
+ int pld_rej_lowdist = 0, pld_rej_tailecho = 0; // guard rejection counters
389
+ bool loop_warned = false; // warn-once state
390
+
391
+ // Adaptive K state: recent accept counts for moving-average decisions
392
+ const int ADAPT_WINDOW = 8;
393
+ std::vector<int> recent_accepts;
394
+ int current_k = args.pld_k;
395
+ bool pld_disabled_adapt = false; // set true when recent accept rate is too low to benefit
396
+
397
+ while (st.decoded < args.n_predict - 1 && !st.hit_eos) {
398
+ // Adaptive K: scale K with recent accept rate.
399
+ // No auto-disable: since S=K+1 forward ≈ S=1 forward (latency-bound), even accept=0.1
400
+ // still nets slightly positive — PLD doesn't "hurt" as long as ngram lookup is cheap.
401
+ if (args.pld_adaptive && (int)recent_accepts.size() >= ADAPT_WINDOW) {
402
+ double avg = 0;
403
+ for (int a : recent_accepts) avg += a;
404
+ avg /= recent_accepts.size();
405
+ // Aim: K = 2*avg + 4 (generous window to catch upswings). Clamp [4, 12].
406
+ current_k = std::max(4, std::min(12, (int)std::round(2.0 * avg + 4.0)));
407
+ }
408
+
409
+ // Try PLD speculation path — skip until enough history accumulated
410
+ std::vector<int32_t> draft;
411
+ if (args.pld_enabled && (int)hist.size() >= args.pld_min_hist && is_master) {
412
+ draft = lookup_draft(args.pld_ngram, current_k);
413
+ // Degeneration guard: if draft looks like repetition-loop echo, drop it so this
414
+ // iteration falls through to normal single decode. This does NOT stop a loop the model
415
+ // is already in (greedy is deterministic), but it prevents PLD from running the loop
416
+ // at batch speed while masquerading as a speedup.
417
+ if (!draft.empty()) {
418
+ const char* reason = draft_degenerate(draft);
419
+ if (reason) {
420
+ if (reason[0] == 'l') pld_rej_lowdist++;
421
+ else pld_rej_tailecho++;
422
+ draft.clear();
423
+ }
424
+ }
425
+ }
426
+ // For TP>1, broadcast draft across ranks. Only broadcast if master has a non-empty draft;
427
+ // otherwise all ranks take the no-draft path (normal decode).
428
+ bool has_draft = is_master ? !draft.empty() : false;
429
+ // Broadcast the "has_draft" flag (using a 1-element count: 1 = yes, 0 = no)
430
+ if (args.tp_size > 1) {
431
+ extern HcclCtx* runner_hccl_ctx_shim(Runner&);
432
+ HcclCtx* ctx = runner_hccl_ctx_shim(runner);
433
+ DeviceBuffer flag(4);
434
+ int32_t f = has_draft ? 1 : 0;
435
+ ACL_CHECK(aclrtMemcpy(flag.get(), 4, &f, 4, ACL_MEMCPY_HOST_TO_DEVICE));
436
+ hccl_broadcast(*ctx, flag.get(), 1, HCCL_DATA_TYPE_INT32, 0, runner.stream());
437
+ ACL_CHECK(aclrtSynchronizeStream(runner.stream()));
438
+ ACL_CHECK(aclrtMemcpy(&f, 4, flag.get(), 4, ACL_MEMCPY_DEVICE_TO_HOST));
439
+ has_draft = (f != 0);
440
+ if (has_draft) {
441
+ std::vector<int32_t> d = draft;
442
+ broadcast_token_ids(runner, d, args.max_seq, is_master);
443
+ draft = d;
444
+ } else {
445
+ draft.clear();
446
+ }
447
+ }
448
+
449
+ if (args.pld_enabled && (int)draft.size() >= 1 && args.temperature == 0.0f) {
450
+ // Batch verify: input = [next_id, draft[0], ..., draft[K-1]]
451
+ std::vector<int32_t> batch_input = { next_id };
452
+ for (auto d : draft) batch_input.push_back(d);
453
+ int S = (int)batch_input.size();
454
+ DeviceBuffer batch_logits;
455
+ if (!runner.decode_batch(batch_input.data(), S, batch_logits)) break;
456
+
457
+ std::vector<uint16_t> blh(S * V);
458
+ if (is_master) ACL_CHECK(aclrtMemcpy(blh.data(), S*V*2, batch_logits.get(), S*V*2, ACL_MEMCPY_DEVICE_TO_HOST));
459
+
460
+ // Accept longest prefix: draft[i] is "candidate" for position past+i+1.
461
+ // blh row i predicts position past+i+1 (follows batch_input[i]).
462
+ // Verify: blh[0].argmax == draft[0]? (i.e., does model agree with draft's first proposal)
463
+ int accept = 0, new_next = next_id;
464
+ if (is_master) {
465
+ for (int i = 0; i < S - 1; i++) {
466
+ int pred = 0; float bv = bf16_to_float(blh[i * V]);
467
+ for (int k = 1; k < V; k++) { float v = bf16_to_float(blh[i*V + k]); if (v > bv) { bv = v; pred = k; } }
468
+ if (pred == (int)draft[i]) accept++;
469
+ else { new_next = pred; break; }
470
+ }
471
+ if (accept == S - 1) {
472
+ // All draft accepted, bonus from last row
473
+ int pred = 0; float bv = bf16_to_float(blh[(S-1) * V]);
474
+ for (int k = 1; k < V; k++) { float v = bf16_to_float(blh[(S-1)*V + k]); if (v > bv) { bv = v; pred = k; } }
475
+ new_next = pred;
476
+ }
477
+ }
478
+
479
+ // Broadcast accept count + new_next across TP ranks
480
+ if (args.tp_size > 1) {
481
+ int32_t packed[2] = { (int32_t)accept, (int32_t)new_next };
482
+ std::vector<int32_t> p(packed, packed + 2);
483
+ broadcast_token_ids(runner, p, args.max_seq, is_master);
484
+ if (p.size() == 2) { accept = p[0]; new_next = p[1]; }
485
+ }
486
+
487
+ // Rewind KV for rejected drafts
488
+ int64_t rewind = (int64_t)(S - 1 - accept); // drafts not accepted (excluding bonus)
489
+ if (rewind > 0) runner.rewind_cache(rewind);
490
+
491
+ // Commit accepted drafts + bonus to hist and emit
492
+ for (int i = 0; i < accept; i++) {
493
+ int tok = (int)draft[i];
494
+ hist.push_back(tok);
495
+ generated.push_back(tok);
496
+ st.decoded++;
497
+ if (is_master && args.stream) { printf("%s", tokenizer.decode(tok).c_str()); fflush(stdout); }
498
+ if (is_eos(tok)) { st.hit_eos = true; break; }
499
+ }
500
+ pld_verifies++; pld_accepted += accept;
501
+ // Track recent accept for adaptive K
502
+ if (args.pld_adaptive) {
503
+ recent_accepts.push_back(accept);
504
+ if ((int)recent_accepts.size() > ADAPT_WINDOW) recent_accepts.erase(recent_accepts.begin());
505
+ }
506
+ if (st.hit_eos) break;
507
+
508
+ // Bonus token (new_next) is also committed
509
+ hist.push_back(new_next);
510
+ generated.push_back(new_next);
511
+ st.decoded++;
512
+ if (is_master && args.stream) { printf("%s", tokenizer.decode(new_next).c_str()); fflush(stdout); }
513
+ if (is_eos(new_next)) { st.hit_eos = true; break; }
514
+ next_id = new_next;
515
+ } else {
516
+ // Normal decode
517
+ DeviceBuffer logits2;
518
+ if (!runner.decode((int32_t)next_id, logits2)) break;
519
+ load_logits(logits2);
520
+ next_id = sample_token(logits_h, V, args.temperature, args.top_k, args.top_p, rng);
521
+ hist.push_back(next_id);
522
+ generated.push_back(next_id);
523
+ st.decoded++;
524
+ if (is_master && args.stream) { printf("%s", tokenizer.decode(next_id).c_str()); fflush(stdout); }
525
+ if (is_eos(next_id)) { st.hit_eos = true; break; }
526
+ }
527
+ // Loop-warn: emit a one-shot warning to stderr if the tail of generated is all-same-token.
528
+ // Does not stop generation (user may want to see what happens) — just flags output quality.
529
+ if (is_master && !loop_warned && args.pld_loop_warn > 0 &&
530
+ (int)generated.size() >= args.pld_loop_warn) {
531
+ int tail = args.pld_loop_warn;
532
+ int anchor = generated[(int)generated.size() - tail];
533
+ bool all_same = true;
534
+ for (int i = (int)generated.size() - tail + 1; i < (int)generated.size(); i++) {
535
+ if (generated[i] != anchor) { all_same = false; break; }
536
+ }
537
+ if (all_same) {
538
+ fprintf(stderr, "\n[warn] %d consecutive identical tokens — likely degeneration loop; output after this point is suspect\n", tail);
539
+ loop_warned = true;
540
+ }
541
+ }
542
+ }
543
+ auto t3 = std::chrono::steady_clock::now();
544
+ st.decode_ms = std::chrono::duration<double, std::milli>(t3 - t2).count();
545
+ if (is_master && args.pld_enabled) {
546
+ if (pld_verifies > 0) {
547
+ fprintf(stderr, "\n[pld] %d verifies, %d drafts accepted, avg=%.2f",
548
+ pld_verifies, pld_accepted, (double)pld_accepted / pld_verifies);
549
+ } else {
550
+ fprintf(stderr, "\n[pld] 0 verifies (all drafts blocked or none found)");
551
+ }
552
+ if (args.pld_guard && (pld_rej_lowdist + pld_rej_tailecho) > 0) {
553
+ fprintf(stderr, "; guard rejections: low-distinct=%d tail-echo=%d",
554
+ pld_rej_lowdist, pld_rej_tailecho);
555
+ }
556
+ fprintf(stderr, "\n");
557
+ }
558
+
559
+ if (is_master) {
560
+ if (args.stream) { printf("\n"); fflush(stdout); }
561
+ else {
562
+ std::string text = tokenizer.decode(generated);
563
+ printf("%s%s\n", args.chat_template ? "" : prompt.c_str(), text.c_str());
564
+ }
565
+ }
566
+ return st;
567
+ }
568
+
569
+ static bool load_file(const std::string& path, std::string& out) {
570
+ FILE* f = fopen(path.c_str(), "rb");
571
+ if (!f) { fprintf(stderr, "[cli] cannot open %s\n", path.c_str()); return false; }
572
+ fseek(f, 0, SEEK_END); long sz = ftell(f); fseek(f, 0, SEEK_SET);
573
+ out.resize(sz);
574
+ size_t n = fread(out.data(), 1, sz, f);
575
+ fclose(f);
576
+ if ((long)n != sz) { fprintf(stderr, "[cli] short read from %s\n", path.c_str()); return false; }
577
+ // Strip a single trailing newline (common in text files)
578
+ if (!out.empty() && out.back() == '\n') out.pop_back();
579
+ return true;
580
+ }
581
+
582
+ int main(int argc, char** argv) {
583
+ Args args;
584
+ if (!parse_args(argc, argv, args)) return 1;
585
+
586
+ // --prompt-file overrides --prompt
587
+ if (!args.prompt_file.empty()) {
588
+ if (!load_file(args.prompt_file, args.prompt)) return 1;
589
+ }
590
+
591
+ const bool is_master = (args.tp_rank == 0);
592
+ std::mt19937 rng(args.seed ? args.seed :
593
+ (uint64_t)std::chrono::steady_clock::now().time_since_epoch().count());
594
+
595
+ if (is_master) {
596
+ printf("[cli] model=%s\n", args.model_dir.c_str());
597
+ printf("[cli] tp=%d n_predict=%d temp=%.2f top_k=%d top_p=%.2f chat=%d interactive=%d\n",
598
+ args.tp_size, args.n_predict, args.temperature, args.top_k, args.top_p,
599
+ args.chat_template, args.interactive);
600
+ fflush(stdout);
601
+ }
602
+
603
+ Tokenizer tokenizer;
604
+ if (!tokenizer.load(args.vocab_path)) {
605
+ fprintf(stderr, "[cli] failed to load vocab %s\n", args.vocab_path.c_str()); return 1;
606
+ }
607
+
608
+ Runner runner;
609
+ int num_layers = args.num_layers;
610
+ if (num_layers == 0) {
611
+ ModelConfig probe;
612
+ if (!probe.load_from_json(args.model_dir + "/config.json")) return 1;
613
+ num_layers = (int)probe.num_hidden_layers;
614
+ }
615
+ if (!runner.init(args.model_dir, args.tp_size, args.tp_rank,
616
+ num_layers, args.max_seq, args.device_id)) return 1;
617
+ if (const char* p = std::getenv("LCA_PROFILE"); p && std::atoi(p) != 0) {
618
+ runner.profile_enabled = true;
619
+ }
620
+ // Warmup: cut cold-start latency. Controlled via LCA_WARMUP env (default 0 to keep behavior).
621
+ if (const char* w = std::getenv("LCA_WARMUP"); w) {
622
+ int n = std::atoi(w);
623
+ if (n > 0) runner.warmup(n);
624
+ }
625
+
626
+ if (args.interactive) {
627
+ const bool multi_turn = args.chat_template && !args.reset_each_turn;
628
+ if (is_master) {
629
+ printf("\n[cli] === interactive mode ===\n");
630
+ if (multi_turn) {
631
+ printf("[cli] multi-turn chat (KV cache preserved). Commands: 'quit', 'reset'.\n");
632
+ if (!args.system_prompt.empty()) {
633
+ printf("[cli] system: %s\n", args.system_prompt.c_str());
634
+ }
635
+ } else {
636
+ printf("[cli] stateless mode (KV cache reset each turn). Command: 'quit'.\n");
637
+ if (!args.chat_template) {
638
+ printf("[cli] (hint: add --chat for multi-turn conversational memory)\n");
639
+ }
640
+ }
641
+ fflush(stdout);
642
+ }
643
+
644
+ // Conversation history: accumulated (role, content) pairs. System prompt seeded if present.
645
+ std::vector<std::pair<std::string, std::string>> conversation;
646
+ if (multi_turn && !args.system_prompt.empty()) {
647
+ conversation.emplace_back("system", args.system_prompt);
648
+ }
649
+
650
+ auto* hccl_ctx = runner_hccl_ctx_shim(runner);
651
+
652
+ // Signal types (broadcast as int32): 0 = normal turn, 1 = quit, 2 = reset.
653
+ auto broadcast_signal = [&](int32_t sig)->int32_t {
654
+ if (args.tp_size <= 1) return sig;
655
+ DeviceBuffer s(4);
656
+ ACL_CHECK(aclrtMemcpy(s.get(), 4, &sig, 4, ACL_MEMCPY_HOST_TO_DEVICE));
657
+ hccl_broadcast(*hccl_ctx, s.get(), 1, HCCL_DATA_TYPE_INT32, 0, runner.stream());
658
+ ACL_CHECK(aclrtSynchronizeStream(runner.stream()));
659
+ int32_t r; ACL_CHECK(aclrtMemcpy(&r, 4, s.get(), 4, ACL_MEMCPY_DEVICE_TO_HOST));
660
+ return r;
661
+ };
662
+
663
+ while (true) {
664
+ std::string prompt;
665
+ int32_t sig = 0;
666
+ if (is_master) {
667
+ printf("\n> "); fflush(stdout);
668
+ if (!std::getline(std::cin, prompt)) sig = 1;
669
+ else if (prompt == "quit" || prompt == "exit") sig = 1;
670
+ else if (prompt == "reset") sig = 2;
671
+ else if (prompt.empty()) sig = 3; // skip
672
+ }
673
+ sig = broadcast_signal(sig);
674
+ if (sig == 1) break;
675
+ if (sig == 2) {
676
+ runner.reset_cache();
677
+ conversation.clear();
678
+ if (multi_turn && !args.system_prompt.empty())
679
+ conversation.emplace_back("system", args.system_prompt);
680
+ if (is_master) { printf("[cli] (cache + conversation reset)\n"); fflush(stdout); }
681
+ continue;
682
+ }
683
+ if (sig == 3) continue;
684
+
685
+ TurnStats st;
686
+ if (multi_turn) {
687
+ // Append user message and tokenize full conversation. Prefill DELTA only.
688
+ if (is_master) conversation.emplace_back("user", prompt);
689
+ // Also ranks 1..N-1 need to track conversation (needed for correct delta count on
690
+ // subsequent turns if TP ever tokenizes per-rank — currently rank 0 tokenizes).
691
+ std::vector<int32_t> full_ids;
692
+ if (is_master) {
693
+ auto raw = tokenizer.encode_conversation_via_python(args.model_dir, conversation, /*gen_prompt=*/true);
694
+ full_ids.reserve(raw.size());
695
+ for (int v : raw) full_ids.push_back((int32_t)v);
696
+ }
697
+ // Broadcast full_ids (variable-length). Use the same shim as broadcast_token_ids.
698
+ if (args.tp_size > 1) {
699
+ if (!broadcast_token_ids(runner, full_ids, args.max_seq, is_master)) break;
700
+ }
701
+ if (full_ids.empty()) { if (is_master) printf("[cli] tokenize failed\n"); continue; }
702
+
703
+ int64_t past = runner.past_len();
704
+ if ((int64_t)full_ids.size() < past) { runner.reset_cache(); past = 0; }
705
+ std::vector<int32_t> delta(full_ids.begin() + past, full_ids.end());
706
+ if (delta.empty()) {
707
+ if (is_master) printf("[cli] (no new tokens)\n");
708
+ continue;
709
+ }
710
+ // Overflow check — simple policy: warn + auto-reset if the turn + generation
711
+ // would exceed max_seq. Conversation history is cleared (except --system) so
712
+ // the user's current prompt still fits.
713
+ if ((int64_t)(past + delta.size()) + args.n_predict > args.max_seq) {
714
+ if (is_master) {
715
+ fprintf(stderr, "[cli] context %ld + gen %d > max_seq %d — auto-resetting\n",
716
+ (long)(past + delta.size()), args.n_predict, args.max_seq);
717
+ }
718
+ runner.reset_cache();
719
+ // Rebuild conversation: keep only system + current user turn.
720
+ if (is_master) {
721
+ std::vector<std::pair<std::string, std::string>> fresh;
722
+ for (auto& m : conversation) if (m.first == "system") fresh.push_back(m);
723
+ if (!conversation.empty() && conversation.back().first == "user") {
724
+ fresh.push_back(conversation.back());
725
+ }
726
+ conversation = std::move(fresh);
727
+ auto raw = tokenizer.encode_conversation_via_python(args.model_dir, conversation, true);
728
+ full_ids.clear();
729
+ for (int v : raw) full_ids.push_back((int32_t)v);
730
+ }
731
+ if (args.tp_size > 1) {
732
+ if (!broadcast_token_ids(runner, full_ids, args.max_seq, is_master)) break;
733
+ }
734
+ delta.assign(full_ids.begin(), full_ids.end());
735
+ past = 0;
736
+ }
737
+
738
+ // --- Prefill the delta ---
739
+ st.n_prompt = (int)delta.size();
740
+ auto t0 = std::chrono::steady_clock::now();
741
+ DeviceBuffer logits;
742
+ if (!runner.prefill(delta.data(), (int64_t)delta.size(), logits)) break;
743
+ auto t1 = std::chrono::steady_clock::now();
744
+ st.prefill_ms = std::chrono::duration<double, std::milli>(t1 - t0).count();
745
+
746
+ const int64_t V = runner.cfg().vocab_size;
747
+ std::vector<uint16_t> logits_h(V);
748
+ auto load_logits = [&](DeviceBuffer& buf) {
749
+ ACL_CHECK(aclrtMemcpy(logits_h.data(), V*2, buf.get(), V*2, ACL_MEMCPY_DEVICE_TO_HOST));
750
+ };
751
+ auto is_eos = [&](int id) {
752
+ for (int e : args.eos_ids) if (id == e) return true;
753
+ return false;
754
+ };
755
+ load_logits(logits);
756
+ int next_id = sample_token(logits_h, V, args.temperature, args.top_k, args.top_p, rng);
757
+
758
+ std::vector<int> assistant_ids = { next_id };
759
+ if (is_master) { printf("%s", tokenizer.decode(next_id).c_str()); fflush(stdout); }
760
+ st.hit_eos = is_eos(next_id);
761
+
762
+ auto t2 = std::chrono::steady_clock::now();
763
+ for (int step = 1; step < args.n_predict && !st.hit_eos; step++) {
764
+ DeviceBuffer logits2;
765
+ if (!runner.decode((int32_t)next_id, logits2)) break;
766
+ load_logits(logits2);
767
+ next_id = sample_token(logits_h, V, args.temperature, args.top_k, args.top_p, rng);
768
+ assistant_ids.push_back(next_id);
769
+ st.decoded++;
770
+ if (is_master) { printf("%s", tokenizer.decode(next_id).c_str()); fflush(stdout); }
771
+ if (is_eos(next_id)) { st.hit_eos = true; break; }
772
+ }
773
+ auto t3 = std::chrono::steady_clock::now();
774
+ st.decode_ms = std::chrono::duration<double, std::milli>(t3 - t2).count();
775
+ if (is_master) { printf("\n"); fflush(stdout); }
776
+
777
+ // Record assistant reply in conversation (strip trailing EOS before decode,
778
+ // and trim incomplete UTF-8 tail if generation was cut mid-codepoint).
779
+ if (is_master) {
780
+ std::vector<int> content_ids;
781
+ for (int id : assistant_ids) { if (is_eos(id)) break; content_ids.push_back(id); }
782
+ conversation.emplace_back("assistant", utf8_trim_incomplete(tokenizer.decode(content_ids)));
783
+ }
784
+ } else {
785
+ // Stateless: reset cache, one-shot prompt
786
+ runner.reset_cache();
787
+ st = run_turn(runner, tokenizer, args, prompt, rng, is_master);
788
+ }
789
+
790
+ if (is_master) {
791
+ double tgs = (st.decode_ms > 0) ? (st.decoded * 1000.0 / st.decode_ms) : 0.0;
792
+ printf("[perf] prefill %d tok %.0fms decode %d tok %.0fms = %.2f t/s%s past_len=%ld\n",
793
+ st.n_prompt, st.prefill_ms, st.decoded, st.decode_ms, tgs,
794
+ st.hit_eos ? " (EOS)" : "", runner.past_len());
795
+ fflush(stdout);
796
+ }
797
+ }
798
+ if (is_master) printf("[cli] bye\n");
799
+ return 0;
800
+ }
801
+
802
+ // One-shot mode
803
+ TurnStats st = run_turn(runner, tokenizer, args, args.prompt, rng, is_master);
804
+ if (is_master) runner.print_profile_summary();
805
+ if (is_master) {
806
+ if (st.hit_eos) printf("[cli] (hit EOS)\n");
807
+ printf("\n[perf] prefill: %.1fms for %d tokens = %.2f t/s\n",
808
+ st.prefill_ms, st.n_prompt,
809
+ (st.prefill_ms > 0) ? (st.n_prompt * 1000.0 / st.prefill_ms) : 0.0);
810
+ if (st.decoded > 0) {
811
+ printf("[perf] decode : %.1fms for %d tokens = %.2f t/s (TG)\n",
812
+ st.decode_ms, st.decoded, (st.decoded * 1000.0) / st.decode_ms);
813
+ }
814
+ }
815
+ return 0;
816
+ }
src/model_config.cpp ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include "model_config.h"
2
+
3
+ #include <cstdio>
4
+ #include <fstream>
5
+ #include <sstream>
6
+
7
+ #include "json.hpp"
8
+ using json = nlohmann::json;
9
+
10
+ bool ModelConfig::load_from_json(const std::string& path) {
11
+ std::ifstream f(path);
12
+ if (!f) {
13
+ fprintf(stderr, "ModelConfig: cannot open %s\n", path.c_str());
14
+ return false;
15
+ }
16
+ json j;
17
+ try { f >> j; } catch (std::exception& e) {
18
+ fprintf(stderr, "ModelConfig: bad json: %s\n", e.what());
19
+ return false;
20
+ }
21
+
22
+ auto get = [&](const char* k, auto def) {
23
+ if (j.contains(k) && !j[k].is_null()) return j[k].get<decltype(def)>();
24
+ return def;
25
+ };
26
+
27
+ vocab_size = get("vocab_size", (int64_t)0);
28
+ hidden_size = get("hidden_size", (int64_t)0);
29
+ intermediate_size = get("intermediate_size", (int64_t)0);
30
+ moe_intermediate_size = get("moe_intermediate_size", (int64_t)0);
31
+ num_hidden_layers = get("num_hidden_layers", (int64_t)0);
32
+ num_attention_heads = get("num_attention_heads", (int64_t)0);
33
+ num_key_value_heads = get("num_key_value_heads", (int64_t)0);
34
+ head_dim = get("head_dim", (int64_t)0);
35
+ num_experts = get("num_experts", (int64_t)0);
36
+ num_experts_per_tok = get("num_experts_per_tok", (int64_t)0);
37
+ max_position_embeddings = get("max_position_embeddings", (int64_t)0);
38
+ rope_theta = (float)get("rope_theta", (double)10000.0);
39
+ rms_norm_eps = (float)get("rms_norm_eps", (double)1e-6);
40
+ norm_topk_prob = get("norm_topk_prob", true);
41
+ tie_word_embeddings = get("tie_word_embeddings", false);
42
+ bos_token_id = get("bos_token_id", (int64_t)0);
43
+ eos_token_id = get("eos_token_id", (int64_t)0);
44
+
45
+ // Sanity
46
+ if (num_attention_heads == 0 || head_dim == 0 || hidden_size == 0) {
47
+ fprintf(stderr, "ModelConfig: required fields missing\n");
48
+ return false;
49
+ }
50
+ return true;
51
+ }
52
+
53
+ void ModelConfig::compute_derived(int tps, int tpr) {
54
+ tp_size = tps;
55
+ tp_rank = tpr;
56
+
57
+ // Attention Q: split by head
58
+ if (num_attention_heads % tp_size != 0) {
59
+ fprintf(stderr, "WARN: num_attention_heads=%ld not divisible by tp_size=%d\n",
60
+ num_attention_heads, tp_size);
61
+ }
62
+ n_heads_per_rank = num_attention_heads / tp_size;
63
+ q_dim_per_rank = n_heads_per_rank * head_dim;
64
+
65
+ // Attention KV: GQA sharding.
66
+ // Case A (tp_size <= num_kv_heads): split KV heads across ranks.
67
+ // n_kv_heads_per_rank = num_kv_heads / tp_size
68
+ // Case B (tp_size > num_kv_heads): each rank gets ONE kv head shared by multiple ranks.
69
+ // Ranks in the same "group" share one kv head (ratio = tp_size / num_kv_heads).
70
+ // n_kv_heads_per_rank = 1
71
+ // kv_head_idx_for_rank = tp_rank / (tp_size / num_kv_heads)
72
+ // This matches the GQA semantics: each group of (num_q_heads / num_kv_heads) Q heads
73
+ // shares one KV head. FIAS is given matched Hq (rank-local Q heads) and Hkv=1.
74
+ if (tp_size <= num_key_value_heads && num_key_value_heads % tp_size == 0) {
75
+ n_kv_heads_per_rank = num_key_value_heads / tp_size;
76
+ } else if (tp_size % num_key_value_heads == 0) {
77
+ n_kv_heads_per_rank = 1;
78
+ } else {
79
+ fprintf(stderr, "WARN: non-standard TP/KV head ratio: tp=%d kv=%ld — falling back to replicate-all\n",
80
+ tp_size, num_key_value_heads);
81
+ n_kv_heads_per_rank = num_key_value_heads;
82
+ }
83
+ kv_dim_per_rank = n_kv_heads_per_rank * head_dim;
84
+
85
+ // MoE intermediate dim split
86
+ if (moe_intermediate_size % tp_size != 0) {
87
+ fprintf(stderr, "WARN: moe_intermediate_size=%ld not divisible by tp_size=%d\n",
88
+ moe_intermediate_size, tp_size);
89
+ }
90
+ i_per_rank = moe_intermediate_size / tp_size;
91
+ }
92
+
93
+ std::string ModelConfig::describe() const {
94
+ std::ostringstream os;
95
+ os << "Qwen3MoE config:\n"
96
+ << " vocab_size = " << vocab_size << "\n"
97
+ << " hidden_size = " << hidden_size << "\n"
98
+ << " num_hidden_layers = " << num_hidden_layers << "\n"
99
+ << " num_attention_heads = " << num_attention_heads << "\n"
100
+ << " num_key_value_heads = " << num_key_value_heads << "\n"
101
+ << " head_dim = " << head_dim << "\n"
102
+ << " num_experts = " << num_experts << "\n"
103
+ << " num_experts_per_tok = " << num_experts_per_tok << "\n"
104
+ << " moe_intermediate_size = " << moe_intermediate_size << "\n"
105
+ << " rope_theta = " << rope_theta << "\n"
106
+ << " rms_norm_eps = " << rms_norm_eps << "\n"
107
+ << " max_pos_embeddings = " << max_position_embeddings << "\n"
108
+ << "TP rank " << tp_rank << " / " << tp_size << " derived:\n"
109
+ << " n_heads_per_rank = " << n_heads_per_rank << "\n"
110
+ << " q_dim_per_rank = " << q_dim_per_rank << "\n"
111
+ << " n_kv_heads_per_rank = " << n_kv_heads_per_rank << "\n"
112
+ << " kv_dim_per_rank = " << kv_dim_per_rank << "\n"
113
+ << " i_per_rank = " << i_per_rank << "\n";
114
+ return os.str();
115
+ }
src/runner.cpp ADDED
@@ -0,0 +1,428 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include "runner.h"
2
+
3
+ #include <chrono>
4
+ #include <cstdio>
5
+ #include <cstring>
6
+
7
+ // Expose HCCL context for the CLI broadcast helper.
8
+ HcclCtx* runner_hccl_ctx_shim(Runner& r) { return &r.hccl_ctx(); }
9
+
10
+ bool Runner::init(const std::string& model_dir, int tp_size, int tp_rank,
11
+ int num_layers_to_load, int64_t max_seq, int device_id) {
12
+ if (!cfg_.load_from_json(model_dir + "/config.json")) return false;
13
+ cfg_.compute_derived(tp_size, tp_rank);
14
+ if (num_layers_to_load < 1 || num_layers_to_load > (int)cfg_.num_hidden_layers) {
15
+ fprintf(stderr, "runner: invalid num_layers %d (max %ld)\n",
16
+ num_layers_to_load, cfg_.num_hidden_layers);
17
+ return false;
18
+ }
19
+ num_layers_ = num_layers_to_load;
20
+ max_seq_ = max_seq;
21
+
22
+ if (!st_.open(model_dir)) return false;
23
+ rt_.init(device_id);
24
+
25
+ // HCCL init (no-op if tp_size == 1)
26
+ if (!hccl_init(hccl_ctx_, tp_size, tp_rank)) {
27
+ fprintf(stderr, "runner: HCCL init failed\n");
28
+ return false;
29
+ }
30
+
31
+ DeviceWeightsLoader dw(st_, cfg_);
32
+ printf("runner: loading shared weights (embed, lm_head, final_norm)...\n");
33
+ if (!dw.load_shared(shared_)) return false;
34
+
35
+ attn_.resize(num_layers_);
36
+ moe_.resize(num_layers_);
37
+ k_cache_.resize(num_layers_);
38
+ v_cache_.resize(num_layers_);
39
+
40
+ const int64_t KV_DIM = cfg_.n_kv_heads_per_rank * cfg_.head_dim;
41
+ for (int L = 0; L < num_layers_; L++) {
42
+ printf("runner: loading layer %d/%d...\n", L + 1, num_layers_);
43
+ if (!dw.load_attention(L, attn_[L])) return false;
44
+ if (!dw.load_moe(L, rt_.stream(), moe_[L])) return false;
45
+ k_cache_[L].alloc(max_seq_ * KV_DIM * 2);
46
+ v_cache_[L].alloc(max_seq_ * KV_DIM * 2);
47
+ }
48
+ rt_.sync();
49
+
50
+ // Prefill mask (2048x2048 bool causal)
51
+ const int64_t MASK = 2048;
52
+ std::vector<uint8_t> mh(MASK * MASK, 0);
53
+ for (int i = 0; i < MASK; i++)
54
+ for (int j = i+1; j < MASK; j++) mh[i*MASK + j] = 1;
55
+ prefill_mask_dev_.alloc(MASK * MASK);
56
+ ACL_CHECK(aclrtMemcpy(prefill_mask_dev_.get(), MASK*MASK, mh.data(), MASK*MASK, ACL_MEMCPY_HOST_TO_DEVICE));
57
+
58
+ // Pre-compute RoPE cos/sin table once (covers all positions up to max_seq_)
59
+ rope_cache_build(rope_cache_, max_seq_, cfg_.head_dim, cfg_.rope_theta);
60
+
61
+ past_len_ = 0;
62
+ cur_S_capacity_ = 0;
63
+ return true;
64
+ }
65
+
66
+ static void ensure_sc_(DeviceBuffer& buf, size_t needed) {
67
+ if (buf.size < needed) buf.alloc(needed);
68
+ }
69
+
70
+ static void ensure_all_scratch_(Runner* self, int64_t S, const ModelConfig& cfg,
71
+ DeviceBuffer& q_sc, DeviceBuffer& k_sc, DeviceBuffer& v_sc,
72
+ DeviceBuffer& xn_sc, DeviceBuffer& rstd_sc, DeviceBuffer& rope_sc,
73
+ DeviceBuffer& attn_fias_sc, DeviceBuffer& attn_out_sc,
74
+ DeviceBuffer& moe_xn, DeviceBuffer& moe_rstd, DeviceBuffer& moe_logits,
75
+ DeviceBuffer& moe_topk_w, DeviceBuffer& moe_topk_idx, DeviceBuffer& moe_row_idx,
76
+ DeviceBuffer& moe_ex_x, DeviceBuffer& moe_ex_ri, DeviceBuffer& moe_tpe,
77
+ DeviceBuffer& moe_fwd,
78
+ DeviceBuffer& moe_gate, DeviceBuffer& moe_up, DeviceBuffer& moe_down,
79
+ DeviceBuffer& moe_packed, DeviceBuffer& moe_weighted, DeviceBuffer& moe_out,
80
+ DeviceBuffer& moe_norm_sum,
81
+ DeviceBuffer& x_buf_a, DeviceBuffer& x_buf_b) {
82
+ (void)self;
83
+ const int64_t D = cfg.hidden_size;
84
+ const int64_t Hq = cfg.n_heads_per_rank, Hkv = cfg.n_kv_heads_per_rank;
85
+ const int64_t Dh = cfg.head_dim;
86
+ const int64_t Q_DIM = Hq * Dh;
87
+ const int64_t KV_DIM = Hkv * Dh;
88
+ const int64_t I = cfg.i_per_rank, E = cfg.num_experts, K = cfg.num_experts_per_tok;
89
+ const int64_t TOTAL = S * K;
90
+
91
+ ensure_sc_(q_sc, S * Q_DIM * 2);
92
+ ensure_sc_(k_sc, S * KV_DIM * 2);
93
+ ensure_sc_(v_sc, S * KV_DIM * 2);
94
+ ensure_sc_(xn_sc, S * D * 2);
95
+ ensure_sc_(rstd_sc, S * std::max(Hq, Hkv) * 4);
96
+ ensure_sc_(rope_sc, 1 * S * Hq * Dh * 2);
97
+ ensure_sc_(attn_fias_sc, S * Q_DIM * 2);
98
+ ensure_sc_(attn_out_sc, S * D * 2);
99
+
100
+ ensure_sc_(moe_xn, S * D * 2);
101
+ ensure_sc_(moe_rstd, S * 4);
102
+ ensure_sc_(moe_logits, S * E * 2);
103
+ ensure_sc_(moe_topk_w, S * K * 2);
104
+ ensure_sc_(moe_topk_idx, S * K * 4);
105
+ ensure_sc_(moe_row_idx, S * K * 4);
106
+ ensure_sc_(moe_ex_x, TOTAL * D * 2);
107
+ ensure_sc_(moe_ex_ri, TOTAL * 4);
108
+ ensure_sc_(moe_tpe, E * 8);
109
+ ensure_sc_(moe_fwd, TOTAL * 8);
110
+ ensure_sc_(moe_gate, TOTAL * I * 2);
111
+ ensure_sc_(moe_up, TOTAL * I * 2);
112
+ ensure_sc_(moe_down, TOTAL * D * 2);
113
+ ensure_sc_(moe_packed, TOTAL * D * 2);
114
+ ensure_sc_(moe_weighted, S * K * D * 2);
115
+ ensure_sc_(moe_out, S * D * 2);
116
+ ensure_sc_(moe_norm_sum, S * 2);
117
+
118
+ ensure_sc_(x_buf_a, S * D * 2);
119
+ ensure_sc_(x_buf_b, S * D * 2);
120
+ }
121
+
122
+ void Runner::layer_forward_(int layer_idx, int64_t S, void* x_in, void* x_out, bool batch_decode_mode) {
123
+ const int64_t D = cfg_.hidden_size;
124
+
125
+ // Attention mask selection:
126
+ // prefill (S>1, past=0): 2048×2048 upper-tri + sparse_mode=3 (FIAS internal causal)
127
+ // decode (S==1): mask=nullptr + sparse_mode=0 (single query sees all cache)
128
+ // batch decode (S>1, past>0): S × (past+S) causal-with-past + sparse_mode=0
129
+ aclTensor* mask = nullptr;
130
+ int64_t sparse_mode = -1; // auto
131
+ AclTensorPtr t_mask_ptr;
132
+ if (batch_decode_mode) {
133
+ build_batch_decode_mask_(S);
134
+ int64_t kv_len = past_len_ + S;
135
+ t_mask_ptr = make_contig_tensor(batch_mask_dev_.get(), ACL_BOOL, {1, 1, S, kv_len});
136
+ mask = t_mask_ptr.get();
137
+ sparse_mode = 0;
138
+ } else if (S > 1) {
139
+ // Pure prefill from past=0
140
+ t_mask_ptr = make_contig_tensor(prefill_mask_dev_.get(), ACL_BOOL, {1, 1, 2048, 2048});
141
+ mask = t_mask_ptr.get();
142
+ sparse_mode = 3;
143
+ }
144
+ // else: S=1 decode, mask=nullptr, sparse_mode=0 (auto)
145
+
146
+ attention_forward(
147
+ rt_.stream(), cfg_, attn_[layer_idx],
148
+ x_in, S, past_len_,
149
+ k_cache_[layer_idx].get(), v_cache_[layer_idx].get(), max_seq_,
150
+ mask,
151
+ q_sc_.get(), k_sc_.get(), v_sc_.get(),
152
+ xn_sc_.get(), rstd_sc_.get(), rope_sc_.get(),
153
+ attn_fias_sc_.get(),
154
+ attn_out_sc_.get(),
155
+ (hccl_ctx_.tp_size > 1) ? &hccl_ctx_ : nullptr,
156
+ &rope_cache_,
157
+ sparse_mode);
158
+
159
+ // x1 = x_in + attn_out (residual)
160
+ auto t_x_in = make_contig_tensor(x_in, ACL_BF16, {S, D});
161
+ auto t_attn_out= make_contig_tensor(attn_out_sc_.get(), ACL_BF16, {S, D});
162
+ auto t_x1 = make_contig_tensor(x_buf_a_.get(), ACL_BF16, {S, D});
163
+ {
164
+ float a = 1.0f; aclScalar* al = aclCreateScalar(&a, ACL_FLOAT);
165
+ uint64_t ws = 0; aclOpExecutor* e = nullptr;
166
+ ACLNN_CHECK(aclnnAddGetWorkspaceSize(t_x_in.get(), t_attn_out.get(), al, t_x1.get(), &ws, &e));
167
+ DeviceBuffer wb; if (ws > 0) wb.alloc(ws);
168
+ ACLNN_CHECK(aclnnAdd(wb.get(), ws, e, rt_.stream()));
169
+ aclDestroyScalar(al);
170
+ }
171
+
172
+ // MoE
173
+ moe_forward(
174
+ rt_.stream(), cfg_, attn_[layer_idx], moe_[layer_idx],
175
+ x_buf_a_.get(), S,
176
+ moe_xn_.get(), moe_rstd_.get(),
177
+ moe_logits_.get(),
178
+ moe_topk_w_.get(), moe_topk_idx_.get(), moe_row_idx_.get(),
179
+ moe_ex_x_.get(), moe_ex_ri_.get(), moe_tpe_.get(),
180
+ moe_fwd_.get(),
181
+ moe_gate_.get(), moe_up_.get(), moe_down_.get(),
182
+ moe_packed_.get(), moe_weighted_.get(),
183
+ moe_out_.get(),
184
+ (hccl_ctx_.tp_size > 1) ? &hccl_ctx_ : nullptr,
185
+ moe_norm_sum_.get());
186
+
187
+ // x_out = x1 + moe_out (residual)
188
+ auto t_moe_out = make_contig_tensor(moe_out_.get(), ACL_BF16, {S, D});
189
+ auto t_out = make_contig_tensor(x_out, ACL_BF16, {S, D});
190
+ {
191
+ float a = 1.0f; aclScalar* al = aclCreateScalar(&a, ACL_FLOAT);
192
+ uint64_t ws = 0; aclOpExecutor* e = nullptr;
193
+ ACLNN_CHECK(aclnnAddGetWorkspaceSize(t_x1.get(), t_moe_out.get(), al, t_out.get(), &ws, &e));
194
+ DeviceBuffer wb; if (ws > 0) wb.alloc(ws);
195
+ ACLNN_CHECK(aclnnAdd(wb.get(), ws, e, rt_.stream()));
196
+ aclDestroyScalar(al);
197
+ }
198
+ }
199
+
200
+ void Runner::final_logits_(void* hidden_last, DeviceBuffer& logits_out) {
201
+ // Single-position variant: hidden_last is [1, D], output [1, V].
202
+ final_logits_batch_(hidden_last, 1, logits_out);
203
+ }
204
+
205
+ void Runner::final_logits_batch_(void* hidden, int64_t S, DeviceBuffer& logits_out) {
206
+ const int64_t D = cfg_.hidden_size;
207
+ const int64_t V = cfg_.vocab_size;
208
+
209
+ DeviceBuffer hn(S * D * 2), rstd(S * 4);
210
+ auto t_h = make_contig_tensor(hidden, ACL_BF16, {S, D});
211
+ auto t_hn = make_contig_tensor(hn.get(), ACL_BF16, {S, D});
212
+ auto t_lnw = make_contig_tensor(shared_.final_norm.get(), ACL_BF16, {D});
213
+ auto t_rstd = make_contig_tensor(rstd.get(), ACL_FLOAT, {S});
214
+ rms_norm(rt_.stream(), t_h.get(), t_lnw.get(), cfg_.rms_norm_eps, t_hn.get(), t_rstd.get());
215
+
216
+ logits_out.alloc(S * V * 2);
217
+ auto t_logits = make_contig_tensor(logits_out.get(), ACL_BF16, {S, V});
218
+ linear_hf(rt_.stream(), t_hn.get(), shared_.lm_head.get(), ACL_BF16, V, D, t_logits.get());
219
+ }
220
+
221
+ bool Runner::decode_batch(const int32_t* tokens, int64_t S, DeviceBuffer& all_logits_out) {
222
+ if (S < 1) return false;
223
+ if (past_len_ + S > max_seq_) {
224
+ fprintf(stderr, "runner: decode_batch exceeds max_seq (%ld + %ld > %ld)\n",
225
+ past_len_, S, max_seq_);
226
+ return false;
227
+ }
228
+ const int64_t D = cfg_.hidden_size;
229
+
230
+ ensure_all_scratch_(this, S, cfg_,
231
+ q_sc_, k_sc_, v_sc_, xn_sc_, rstd_sc_, rope_sc_, attn_fias_sc_, attn_out_sc_,
232
+ moe_xn_, moe_rstd_, moe_logits_,
233
+ moe_topk_w_, moe_topk_idx_, moe_row_idx_,
234
+ moe_ex_x_, moe_ex_ri_, moe_tpe_,
235
+ moe_fwd_,
236
+ moe_gate_, moe_up_, moe_down_,
237
+ moe_packed_, moe_weighted_, moe_out_,
238
+ moe_norm_sum_,
239
+ x_buf_a_, x_buf_b_);
240
+
241
+ // Embed S tokens
242
+ DeviceBuffer tok_dev(S * 4);
243
+ ACL_CHECK(aclrtMemcpy(tok_dev.get(), S*4, tokens, S*4, ACL_MEMCPY_HOST_TO_DEVICE));
244
+ auto t_tok = make_contig_tensor(tok_dev.get(), ACL_INT32, {S});
245
+ auto t_embed_w = make_contig_tensor(shared_.embed_tokens.get(), ACL_BF16, {cfg_.vocab_size, D});
246
+
247
+ DeviceBuffer x0(S * D * 2);
248
+ auto t_x0 = make_contig_tensor(x0.get(), ACL_BF16, {S, D});
249
+ index_select(rt_.stream(), t_embed_w.get(), 0, t_tok.get(), t_x0.get());
250
+
251
+ DeviceBuffer xping(S * D * 2), xpong(S * D * 2);
252
+ ACL_CHECK(aclrtMemcpyAsync(xping.get(), S*D*2, x0.get(), S*D*2, ACL_MEMCPY_DEVICE_TO_DEVICE, rt_.stream()));
253
+ void* cur_in = xping.get();
254
+ void* cur_out = xpong.get();
255
+ // batch_decode_mode=true uses proper causal-with-past mask (S × past+S, sparse_mode=0).
256
+ for (int L = 0; L < num_layers_; L++) {
257
+ layer_forward_(L, S, cur_in, cur_out, /*batch_decode_mode=*/past_len_ > 0);
258
+ std::swap(cur_in, cur_out);
259
+ }
260
+ rt_.sync();
261
+
262
+ // Get logits for ALL S positions (not just last)
263
+ final_logits_batch_(cur_in, S, all_logits_out);
264
+ rt_.sync();
265
+
266
+ past_len_ += S;
267
+ return true;
268
+ }
269
+
270
+ bool Runner::prefill(const int32_t* tokens, int64_t S, DeviceBuffer& logits_out) {
271
+ if (S < 1) return false;
272
+ if (past_len_ + S > max_seq_) {
273
+ fprintf(stderr, "runner: prefill exceeds max_seq (%ld + %ld > %ld)\n",
274
+ past_len_, S, max_seq_);
275
+ return false;
276
+ }
277
+
278
+ const int64_t D = cfg_.hidden_size;
279
+ ensure_all_scratch_(this, S, cfg_,
280
+ q_sc_, k_sc_, v_sc_, xn_sc_, rstd_sc_, rope_sc_, attn_fias_sc_, attn_out_sc_,
281
+ moe_xn_, moe_rstd_, moe_logits_,
282
+ moe_topk_w_, moe_topk_idx_, moe_row_idx_,
283
+ moe_ex_x_, moe_ex_ri_, moe_tpe_,
284
+ moe_fwd_,
285
+ moe_gate_, moe_up_, moe_down_,
286
+ moe_packed_, moe_weighted_, moe_out_,
287
+ moe_norm_sum_,
288
+ x_buf_a_, x_buf_b_);
289
+
290
+ // Embed
291
+ DeviceBuffer tok_dev(S * 4);
292
+ ACL_CHECK(aclrtMemcpy(tok_dev.get(), S*4, tokens, S*4, ACL_MEMCPY_HOST_TO_DEVICE));
293
+ auto t_tok = make_contig_tensor(tok_dev.get(), ACL_INT32, {S});
294
+ auto t_embed_w = make_contig_tensor(shared_.embed_tokens.get(), ACL_BF16, {cfg_.vocab_size, D});
295
+
296
+ DeviceBuffer x0(S * D * 2);
297
+ auto t_x0 = make_contig_tensor(x0.get(), ACL_BF16, {S, D});
298
+ index_select(rt_.stream(), t_embed_w.get(), 0, t_tok.get(), t_x0.get());
299
+
300
+ // Layer chain: ping-pong between two buffers
301
+ DeviceBuffer xping(S * D * 2), xpong(S * D * 2);
302
+ ACL_CHECK(aclrtMemcpyAsync(xping.get(), S*D*2, x0.get(), S*D*2, ACL_MEMCPY_DEVICE_TO_DEVICE, rt_.stream()));
303
+
304
+ void* cur_in = xping.get();
305
+ void* cur_out = xpong.get();
306
+ for (int L = 0; L < num_layers_; L++) {
307
+ layer_forward_(L, S, cur_in, cur_out);
308
+ std::swap(cur_in, cur_out);
309
+ }
310
+ rt_.sync();
311
+
312
+ // Take last position's hidden → final_logits
313
+ DeviceBuffer last(1 * D * 2);
314
+ ACL_CHECK(aclrtMemcpy(last.get(), 1*D*2,
315
+ (char*)cur_in + (S - 1) * D * 2, 1*D*2,
316
+ ACL_MEMCPY_DEVICE_TO_DEVICE));
317
+ final_logits_(last.get(), logits_out);
318
+ rt_.sync();
319
+
320
+ past_len_ += S;
321
+ return true;
322
+ }
323
+
324
+ bool Runner::decode(int32_t token, DeviceBuffer& logits_out) {
325
+ const int64_t D = cfg_.hidden_size;
326
+ if (past_len_ + 1 > max_seq_) {
327
+ fprintf(stderr, "runner: decode exceeds max_seq\n");
328
+ return false;
329
+ }
330
+
331
+ const int64_t S = 1;
332
+ ensure_all_scratch_(this, S, cfg_,
333
+ q_sc_, k_sc_, v_sc_, xn_sc_, rstd_sc_, rope_sc_, attn_fias_sc_, attn_out_sc_,
334
+ moe_xn_, moe_rstd_, moe_logits_,
335
+ moe_topk_w_, moe_topk_idx_, moe_row_idx_,
336
+ moe_ex_x_, moe_ex_ri_, moe_tpe_,
337
+ moe_fwd_,
338
+ moe_gate_, moe_up_, moe_down_,
339
+ moe_packed_, moe_weighted_, moe_out_,
340
+ moe_norm_sum_,
341
+ x_buf_a_, x_buf_b_);
342
+
343
+ DeviceBuffer tok_dev(1 * 4);
344
+ ACL_CHECK(aclrtMemcpy(tok_dev.get(), 4, &token, 4, ACL_MEMCPY_HOST_TO_DEVICE));
345
+ auto t_tok = make_contig_tensor(tok_dev.get(), ACL_INT32, {1});
346
+ auto t_embed_w = make_contig_tensor(shared_.embed_tokens.get(), ACL_BF16, {cfg_.vocab_size, D});
347
+
348
+ auto t0 = std::chrono::steady_clock::now();
349
+
350
+ DeviceBuffer x0(1 * D * 2);
351
+ auto t_x0 = make_contig_tensor(x0.get(), ACL_BF16, {1, D});
352
+ index_select(rt_.stream(), t_embed_w.get(), 0, t_tok.get(), t_x0.get());
353
+
354
+ DeviceBuffer xping(1 * D * 2), xpong(1 * D * 2);
355
+ ACL_CHECK(aclrtMemcpyAsync(xping.get(), 1*D*2, x0.get(), 1*D*2, ACL_MEMCPY_DEVICE_TO_DEVICE, rt_.stream()));
356
+ if (profile_enabled) { ACL_CHECK(aclrtSynchronizeStream(rt_.stream())); }
357
+ auto t1 = std::chrono::steady_clock::now();
358
+
359
+ void* cur_in = xping.get();
360
+ void* cur_out = xpong.get();
361
+ for (int L = 0; L < num_layers_; L++) {
362
+ layer_forward_(L, 1, cur_in, cur_out);
363
+ std::swap(cur_in, cur_out);
364
+ }
365
+ rt_.sync();
366
+ auto t2 = std::chrono::steady_clock::now();
367
+
368
+ final_logits_(cur_in, logits_out);
369
+ rt_.sync();
370
+ auto t3 = std::chrono::steady_clock::now();
371
+
372
+ if (profile_enabled) {
373
+ using ms = std::chrono::duration<double, std::milli>;
374
+ t_embed_ms += ms(t1 - t0).count();
375
+ t_layers_ms += ms(t2 - t1).count();
376
+ t_final_ms += ms(t3 - t2).count();
377
+ profile_calls++;
378
+ }
379
+
380
+ past_len_ += 1;
381
+ return true;
382
+ }
383
+
384
+ void Runner::build_batch_decode_mask_(int64_t S) {
385
+ int64_t kv_len = past_len_ + S;
386
+ size_t bytes = (size_t)S * kv_len; // bool = 1 byte
387
+ if (batch_mask_dev_.size < bytes) batch_mask_dev_.alloc(bytes);
388
+ std::vector<uint8_t> h_mask(bytes, 0);
389
+ for (int64_t i = 0; i < S; i++) {
390
+ // Row i: positions j ≤ past_len_+i are visible (0), j > past_len_+i are masked (1).
391
+ for (int64_t j = past_len_ + i + 1; j < kv_len; j++) {
392
+ h_mask[i * kv_len + j] = 1;
393
+ }
394
+ }
395
+ ACL_CHECK(aclrtMemcpy(batch_mask_dev_.get(), bytes, h_mask.data(), bytes,
396
+ ACL_MEMCPY_HOST_TO_DEVICE));
397
+ }
398
+
399
+ void Runner::warmup(int iterations) {
400
+ if (num_layers_ == 0) return;
401
+ int64_t saved_past = past_len_;
402
+ past_len_ = 0;
403
+ int32_t dummy_tok = 0; // token id 0, valid for Qwen3 (bos)
404
+ DeviceBuffer dummy_logits;
405
+ for (int i = 0; i < iterations; i++) {
406
+ past_len_ = 0;
407
+ if (!decode(dummy_tok, dummy_logits)) break;
408
+ }
409
+ past_len_ = saved_past;
410
+ fprintf(stderr, "[runner] warmup: %d iterations done\n", iterations);
411
+ }
412
+
413
+ void Runner::print_profile_summary() const {
414
+ if (!profile_enabled || profile_calls == 0) return;
415
+ double total = t_embed_ms + t_layers_ms + t_final_ms;
416
+ fprintf(stderr, "\n=== Runner profile (%ld decode calls) ===\n", profile_calls);
417
+ fprintf(stderr, " phase total_ms avg_ms/call pct\n");
418
+ fprintf(stderr, " embed %8.1f %10.3f %5.1f%%\n",
419
+ t_embed_ms, t_embed_ms / profile_calls, 100.0 * t_embed_ms / total);
420
+ fprintf(stderr, " layers (x%d) %8.1f %10.3f %5.1f%% → %.3f ms/layer/call\n",
421
+ num_layers_, t_layers_ms, t_layers_ms / profile_calls,
422
+ 100.0 * t_layers_ms / total,
423
+ t_layers_ms / profile_calls / num_layers_);
424
+ fprintf(stderr, " final+lm_hd %8.1f %10.3f %5.1f%%\n",
425
+ t_final_ms, t_final_ms / profile_calls, 100.0 * t_final_ms / total);
426
+ fprintf(stderr, " total %8.1f %10.3f 100.0%%\n",
427
+ total, total / profile_calls);
428
+ }
src/safetensors_loader.cpp ADDED
@@ -0,0 +1,172 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include "safetensors_loader.h"
2
+
3
+ #include <fcntl.h>
4
+ #include <sys/mman.h>
5
+ #include <sys/stat.h>
6
+ #include <unistd.h>
7
+
8
+ #include <cstdio>
9
+ #include <cstring>
10
+ #include <fstream>
11
+ #include <sstream>
12
+
13
+ #include "json.hpp"
14
+
15
+ using json = nlohmann::json;
16
+
17
+ SafetensorsLoader::SafetensorsLoader() = default;
18
+
19
+ SafetensorsLoader::~SafetensorsLoader() {
20
+ for (auto& s : shards_) {
21
+ if (s.mmap_ptr) munmap(s.mmap_ptr, s.mmap_size);
22
+ if (s.fd >= 0) close(s.fd);
23
+ }
24
+ }
25
+
26
+ bool SafetensorsLoader::open(const std::string& dir) {
27
+ model_dir_ = dir;
28
+
29
+ // 1. Parse index.json to discover shard files
30
+ std::string idx_path = dir + "/model.safetensors.index.json";
31
+ std::ifstream idx_file(idx_path);
32
+ if (!idx_file) {
33
+ // Fallback: single-file model
34
+ std::string single = dir + "/model.safetensors";
35
+ std::ifstream f(single);
36
+ if (!f) {
37
+ fprintf(stderr, "SafetensorsLoader: neither index.json nor model.safetensors found in %s\n", dir.c_str());
38
+ return false;
39
+ }
40
+ shards_.push_back({single});
41
+ return parse_shard_header_(0);
42
+ }
43
+
44
+ json idx;
45
+ try { idx_file >> idx; } catch (std::exception& e) {
46
+ fprintf(stderr, "SafetensorsLoader: bad index.json: %s\n", e.what());
47
+ return false;
48
+ }
49
+ if (!idx.contains("weight_map")) {
50
+ fprintf(stderr, "SafetensorsLoader: index.json missing weight_map\n");
51
+ return false;
52
+ }
53
+
54
+ // Collect unique shard filenames (preserving discovery order).
55
+ std::map<std::string, int> shard_name_to_id;
56
+ for (auto& [name, file] : idx["weight_map"].items()) {
57
+ std::string shard_name = file.get<std::string>();
58
+ if (shard_name_to_id.count(shard_name) == 0) {
59
+ int id = (int)shards_.size();
60
+ shard_name_to_id[shard_name] = id;
61
+ shards_.push_back({dir + "/" + shard_name});
62
+ }
63
+ }
64
+
65
+ // 2. Parse header of each shard to discover tensor offsets
66
+ for (int i = 0; i < (int)shards_.size(); i++) {
67
+ if (!parse_shard_header_(i)) {
68
+ fprintf(stderr, "SafetensorsLoader: failed to parse shard %s\n", shards_[i].path.c_str());
69
+ return false;
70
+ }
71
+ }
72
+
73
+ return true;
74
+ }
75
+
76
+ bool SafetensorsLoader::parse_shard_header_(int shard_id) {
77
+ ShardFile& sh = shards_[shard_id];
78
+ std::ifstream f(sh.path, std::ios::binary);
79
+ if (!f) return false;
80
+
81
+ // Read 8-byte little-endian header length
82
+ uint64_t header_len = 0;
83
+ f.read((char*)&header_len, 8);
84
+ if (!f) return false;
85
+
86
+ std::string header(header_len, '\0');
87
+ f.read(header.data(), header_len);
88
+ if (!f) return false;
89
+
90
+ sh.data_base = 8 + header_len;
91
+
92
+ json j;
93
+ try { j = json::parse(header); } catch (std::exception& e) {
94
+ fprintf(stderr, "SafetensorsLoader: bad shard header JSON in %s: %s\n", sh.path.c_str(), e.what());
95
+ return false;
96
+ }
97
+
98
+ for (auto it = j.begin(); it != j.end(); ++it) {
99
+ const std::string& name = it.key();
100
+ if (name == "__metadata__") continue;
101
+ const auto& entry = it.value();
102
+
103
+ TensorMeta m;
104
+ m.name = name;
105
+ m.dtype = entry["dtype"].get<std::string>();
106
+ for (auto& d : entry["shape"]) m.shape.push_back(d.get<int64_t>());
107
+ const auto& offs = entry["data_offsets"];
108
+ size_t begin = offs[0].get<size_t>();
109
+ size_t end = offs[1].get<size_t>();
110
+ m.offset = sh.data_base + begin;
111
+ m.nbytes = end - begin;
112
+ m.shard_id = shard_id;
113
+
114
+ tensors_[name] = std::move(m);
115
+ }
116
+
117
+ return true;
118
+ }
119
+
120
+ bool SafetensorsLoader::mmap_shard_(int shard_id) {
121
+ ShardFile& sh = shards_[shard_id];
122
+ if (sh.mmap_ptr) return true;
123
+
124
+ sh.fd = ::open(sh.path.c_str(), O_RDONLY);
125
+ if (sh.fd < 0) {
126
+ perror("open");
127
+ return false;
128
+ }
129
+ struct stat st;
130
+ if (fstat(sh.fd, &st) != 0) return false;
131
+ sh.mmap_size = st.st_size;
132
+
133
+ sh.mmap_ptr = mmap(nullptr, sh.mmap_size, PROT_READ, MAP_PRIVATE, sh.fd, 0);
134
+ if (sh.mmap_ptr == MAP_FAILED) {
135
+ perror("mmap");
136
+ sh.mmap_ptr = nullptr;
137
+ return false;
138
+ }
139
+ return true;
140
+ }
141
+
142
+ const TensorMeta* SafetensorsLoader::get(const std::string& name) const {
143
+ auto it = tensors_.find(name);
144
+ if (it == tensors_.end()) return nullptr;
145
+ return &it->second;
146
+ }
147
+
148
+ const void* SafetensorsLoader::data_ptr(const TensorMeta& m) {
149
+ if (m.shard_id < 0 || (size_t)m.shard_id >= shards_.size()) return nullptr;
150
+ if (!mmap_shard_(m.shard_id)) return nullptr;
151
+ ShardFile& sh = shards_[m.shard_id];
152
+ return (const char*)sh.mmap_ptr + m.offset;
153
+ }
154
+
155
+ const void* SafetensorsLoader::data_ptr(const std::string& name) {
156
+ const auto* m = get(name);
157
+ if (!m) return nullptr;
158
+ return data_ptr(*m);
159
+ }
160
+
161
+ std::vector<std::string> SafetensorsLoader::list_tensor_names() const {
162
+ std::vector<std::string> out;
163
+ out.reserve(tensors_.size());
164
+ for (auto& [k, v] : tensors_) out.push_back(k);
165
+ return out;
166
+ }
167
+
168
+ size_t SafetensorsLoader::total_bytes() const {
169
+ size_t sum = 0;
170
+ for (auto& [k, v] : tensors_) sum += v.nbytes;
171
+ return sum;
172
+ }
src/tokenizer.cpp ADDED
@@ -0,0 +1,176 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include "tokenizer.h"
2
+
3
+ #include <array>
4
+ #include <cstdint>
5
+ #include <cstdio>
6
+ #include <cstdlib>
7
+ #include <cstring>
8
+ #include <fstream>
9
+ #include <memory>
10
+ #include <sstream>
11
+ #include <unistd.h>
12
+
13
+ bool Tokenizer::load(const std::string& vocab_bin_path) {
14
+ std::ifstream f(vocab_bin_path, std::ios::binary);
15
+ if (!f) {
16
+ fprintf(stderr, "Tokenizer: cannot open %s\n", vocab_bin_path.c_str());
17
+ return false;
18
+ }
19
+ uint32_t num;
20
+ f.read((char*)&num, 4);
21
+ if (!f) return false;
22
+ id_to_bytes_.resize(num);
23
+ for (uint32_t i = 0; i < num; i++) {
24
+ uint32_t len;
25
+ f.read((char*)&len, 4);
26
+ if (!f) return false;
27
+ id_to_bytes_[i].resize(len);
28
+ if (len > 0) f.read(id_to_bytes_[i].data(), len);
29
+ }
30
+ return true;
31
+ }
32
+
33
+ std::string Tokenizer::decode(int id) const {
34
+ if (id < 0 || (size_t)id >= id_to_bytes_.size()) return "";
35
+ return id_to_bytes_[id];
36
+ }
37
+
38
+ std::string Tokenizer::decode(const std::vector<int>& ids) const {
39
+ std::string out;
40
+ for (int id : ids) out += decode(id);
41
+ return out;
42
+ }
43
+
44
+ std::vector<int> Tokenizer::encode_via_python(const std::string& model_dir,
45
+ const std::string& prompt,
46
+ bool apply_chat_template) const {
47
+ // Call python subprocess to tokenize. Embed prompt via stdin to avoid shell-escape bugs.
48
+ std::string cmd;
49
+ // Set QWEN3_PYENV_INIT to override the Python env activation sequence (e.g., "source /opt/my_env/activate && ")
50
+ // Default assumes conda at ~/miniconda3 with env 'qwen3' and Ascend toolkit installed.
51
+ if (const char* init = std::getenv("QWEN3_PYENV_INIT")) {
52
+ cmd += init;
53
+ } else {
54
+ cmd += "source ${HOME}/miniconda3/etc/profile.d/conda.sh 2>/dev/null && ";
55
+ cmd += "conda activate qwen3 2>/dev/null || true; ";
56
+ cmd += "source /usr/local/Ascend/ascend-toolkit/set_env.sh 2>/dev/null || true; ";
57
+ }
58
+ cmd += "python3 -c \"";
59
+ cmd += "import sys, json;";
60
+ cmd += "from transformers import AutoTokenizer;";
61
+ cmd += "t = AutoTokenizer.from_pretrained('" + model_dir + "');";
62
+ cmd += "p = sys.stdin.read();";
63
+ if (apply_chat_template) {
64
+ cmd += "msg = [{'role': 'user', 'content': p}];";
65
+ cmd += "ids = t.apply_chat_template(msg, add_generation_prompt=True);";
66
+ } else {
67
+ cmd += "ids = t.encode(p);";
68
+ }
69
+ cmd += "print(' '.join(str(i) for i in ids));";
70
+ cmd += "\"";
71
+
72
+ // popen with stdin: use the two-pipe dance via temp file for safety
73
+ char tmpl[] = "/tmp/lca_prompt_XXXXXX";
74
+ int fd = mkstemp(tmpl);
75
+ if (fd < 0) { perror("mkstemp"); return {}; }
76
+ write(fd, prompt.data(), prompt.size());
77
+ close(fd);
78
+
79
+ std::string full = cmd + " < " + tmpl + " 2>/dev/null";
80
+ FILE* pipe = popen(full.c_str(), "r");
81
+ if (!pipe) { perror("popen"); unlink(tmpl); return {}; }
82
+
83
+ std::string out;
84
+ char buf[4096];
85
+ while (size_t n = fread(buf, 1, sizeof(buf), pipe)) out.append(buf, n);
86
+ pclose(pipe);
87
+ unlink(tmpl);
88
+
89
+ std::vector<int> ids;
90
+ std::istringstream iss(out);
91
+ int x;
92
+ while (iss >> x) ids.push_back(x);
93
+ return ids;
94
+ }
95
+
96
+ // Shell-quote a string for embedding in a JSON string (escape ", \, control chars).
97
+ static std::string json_escape(const std::string& s) {
98
+ std::string out;
99
+ out.reserve(s.size() + 8);
100
+ for (char c : s) {
101
+ switch (c) {
102
+ case '"': out += "\\\""; break;
103
+ case '\\': out += "\\\\"; break;
104
+ case '\n': out += "\\n"; break;
105
+ case '\r': out += "\\r"; break;
106
+ case '\t': out += "\\t"; break;
107
+ default:
108
+ if ((unsigned char)c < 0x20) {
109
+ char buf[8];
110
+ snprintf(buf, sizeof(buf), "\\u%04x", (unsigned char)c);
111
+ out += buf;
112
+ } else {
113
+ out += c;
114
+ }
115
+ }
116
+ }
117
+ return out;
118
+ }
119
+
120
+ std::vector<int> Tokenizer::encode_conversation_via_python(
121
+ const std::string& model_dir,
122
+ const std::vector<std::pair<std::string, std::string>>& conversation,
123
+ bool add_generation_prompt) const
124
+ {
125
+ // Build JSON array of messages. Pass via stdin to avoid shell-escape issues.
126
+ std::string json_msgs = "[";
127
+ for (size_t i = 0; i < conversation.size(); i++) {
128
+ if (i > 0) json_msgs += ",";
129
+ json_msgs += "{\"role\":\"" + json_escape(conversation[i].first) + "\",";
130
+ json_msgs += "\"content\":\"" + json_escape(conversation[i].second) + "\"}";
131
+ }
132
+ json_msgs += "]";
133
+
134
+ std::string cmd;
135
+ // Set QWEN3_PYENV_INIT to override the Python env activation sequence (e.g., "source /opt/my_env/activate && ")
136
+ // Default assumes conda at ~/miniconda3 with env 'qwen3' and Ascend toolkit installed.
137
+ if (const char* init = std::getenv("QWEN3_PYENV_INIT")) {
138
+ cmd += init;
139
+ } else {
140
+ cmd += "source ${HOME}/miniconda3/etc/profile.d/conda.sh 2>/dev/null && ";
141
+ cmd += "conda activate qwen3 2>/dev/null || true; ";
142
+ cmd += "source /usr/local/Ascend/ascend-toolkit/set_env.sh 2>/dev/null || true; ";
143
+ }
144
+ cmd += "python3 -c \"";
145
+ cmd += "import sys, json;";
146
+ cmd += "from transformers import AutoTokenizer;";
147
+ cmd += "t = AutoTokenizer.from_pretrained('" + model_dir + "');";
148
+ cmd += "msgs = json.loads(sys.stdin.read());";
149
+ cmd += "ids = t.apply_chat_template(msgs, add_generation_prompt=";
150
+ cmd += add_generation_prompt ? "True" : "False";
151
+ cmd += ");";
152
+ cmd += "print(' '.join(str(i) for i in ids));";
153
+ cmd += "\"";
154
+
155
+ char tmpl[] = "/tmp/lca_conv_XXXXXX";
156
+ int fd = mkstemp(tmpl);
157
+ if (fd < 0) { perror("mkstemp"); return {}; }
158
+ write(fd, json_msgs.data(), json_msgs.size());
159
+ close(fd);
160
+
161
+ std::string full = cmd + " < " + tmpl + " 2>/dev/null";
162
+ FILE* pipe = popen(full.c_str(), "r");
163
+ if (!pipe) { perror("popen"); unlink(tmpl); return {}; }
164
+
165
+ std::string out;
166
+ char buf[4096];
167
+ while (size_t n = fread(buf, 1, sizeof(buf), pipe)) out.append(buf, n);
168
+ pclose(pipe);
169
+ unlink(tmpl);
170
+
171
+ std::vector<int> ids;
172
+ std::istringstream iss(out);
173
+ int x;
174
+ while (iss >> x) ids.push_back(x);
175
+ return ids;
176
+ }
tests/hello_acl.cpp ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // hello_acl.cpp — smoke test: aclInit + device + stream + simple tensor + aclnnAdd
2
+ #include "acl_common.h"
3
+ #include <aclnnop/aclnn_add.h>
4
+ #include <cstdio>
5
+ #include <vector>
6
+
7
+ int main() {
8
+ ACL_CHECK(aclInit(nullptr));
9
+ ACL_CHECK(aclrtSetDevice(0));
10
+ aclrtContext ctx;
11
+ ACL_CHECK(aclrtCreateContext(&ctx, 0));
12
+ aclrtStream stream;
13
+ ACL_CHECK(aclrtCreateStream(&stream));
14
+
15
+ // Tiny test: a = [1, 2, 3, 4] f32, b = [10, 20, 30, 40] f32, out = a + b
16
+ const int64_t N = 4;
17
+ std::vector<float> a_host = {1.0f, 2.0f, 3.0f, 4.0f};
18
+ std::vector<float> b_host = {10.0f, 20.0f, 30.0f, 40.0f};
19
+ std::vector<float> out_host(N, 0.0f);
20
+
21
+ DeviceBuffer a_dev(N * sizeof(float));
22
+ DeviceBuffer b_dev(N * sizeof(float));
23
+ DeviceBuffer out_dev(N * sizeof(float));
24
+
25
+ ACL_CHECK(aclrtMemcpy(a_dev.get(), N * 4, a_host.data(), N * 4, ACL_MEMCPY_HOST_TO_DEVICE));
26
+ ACL_CHECK(aclrtMemcpy(b_dev.get(), N * 4, b_host.data(), N * 4, ACL_MEMCPY_HOST_TO_DEVICE));
27
+
28
+ auto a_t = make_contig_tensor(a_dev.get(), ACL_FLOAT, {N});
29
+ auto b_t = make_contig_tensor(b_dev.get(), ACL_FLOAT, {N});
30
+ auto out_t = make_contig_tensor(out_dev.get(), ACL_FLOAT, {N});
31
+
32
+ // aclnnAdd: out = a + alpha * b
33
+ float alpha_val = 1.0f;
34
+ aclScalar* alpha = aclCreateScalar(&alpha_val, ACL_FLOAT);
35
+
36
+ uint64_t ws_size = 0;
37
+ aclOpExecutor* executor = nullptr;
38
+ ACLNN_CHECK(aclnnAddGetWorkspaceSize(a_t.get(), b_t.get(), alpha, out_t.get(), &ws_size, &executor));
39
+
40
+ DeviceBuffer ws;
41
+ if (ws_size > 0) ws.alloc(ws_size);
42
+ ACLNN_CHECK(aclnnAdd(ws.get(), ws_size, executor, stream));
43
+
44
+ ACL_CHECK(aclrtSynchronizeStream(stream));
45
+
46
+ ACL_CHECK(aclrtMemcpy(out_host.data(), N * 4, out_dev.get(), N * 4, ACL_MEMCPY_DEVICE_TO_HOST));
47
+
48
+ printf("hello_acl: ");
49
+ for (int i = 0; i < N; i++) printf("%.1f ", out_host[i]);
50
+ printf("\n");
51
+
52
+ bool ok = (out_host[0] == 11.0f && out_host[1] == 22.0f &&
53
+ out_host[2] == 33.0f && out_host[3] == 44.0f);
54
+ printf(ok ? "PASS\n" : "FAIL\n");
55
+
56
+ aclDestroyScalar(alpha);
57
+ ACL_CHECK(aclrtDestroyStream(stream));
58
+ ACL_CHECK(aclrtDestroyContext(ctx));
59
+ ACL_CHECK(aclrtResetDevice(0));
60
+ aclFinalize();
61
+ return ok ? 0 : 1;
62
+ }
tests/test_attention_decode.cpp ADDED
@@ -0,0 +1,319 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // test_attention_decode.cpp — validates single-layer attention with KV cache.
2
+ //
3
+ // Strategy: compare two paths yielding the same pos-5 attention output:
4
+ // Path A (reference): prefill 6 tokens in one shot → attn_out[5]
5
+ // Path B (decode): prefill 5 tokens → K/V cache; decode 6th token via cache → attn_out_decode[0]
6
+ //
7
+ // The two outputs should match within BF16 precision.
8
+ #include "acl_common.h"
9
+ #include "acl_runtime.h"
10
+ #include "aclnn_ops.h"
11
+ #include "device_weights.h"
12
+ #include "model_config.h"
13
+ #include "rope.h"
14
+ #include "safetensors_loader.h"
15
+
16
+ #include <cmath>
17
+ #include <cstdio>
18
+ #include <cstring>
19
+ #include <fstream>
20
+ #include <vector>
21
+
22
+ static float bf16_to_float(uint16_t x) {
23
+ uint32_t u = (uint32_t)x << 16; float f; std::memcpy(&f, &u, 4); return f;
24
+ }
25
+ static uint16_t float_to_bf16(float x) {
26
+ uint32_t u; std::memcpy(&u, &x, 4);
27
+ return (uint16_t)((u + 0x7FFF + ((u >> 16) & 1)) >> 16);
28
+ }
29
+ static std::vector<uint8_t> read_file(const std::string& p) {
30
+ std::ifstream f(p, std::ios::binary | std::ios::ate); size_t s = f.tellg();
31
+ f.seekg(0); std::vector<uint8_t> v(s); f.read((char*)v.data(), s); return v;
32
+ }
33
+
34
+ // Fill cos/sin tables for a range of positions [p0, p0+L). HF layout: half-half.
35
+ static void fill_cos_sin(std::vector<uint16_t>& cos_h, std::vector<uint16_t>& sin_h,
36
+ int64_t p0, int64_t L, int64_t Dh, float theta) {
37
+ cos_h.resize(L * Dh); sin_h.resize(L * Dh);
38
+ int64_t half = Dh / 2;
39
+ for (int64_t s = 0; s < L; s++) {
40
+ for (int64_t d = 0; d < Dh; d++) {
41
+ int64_t pair = (d < half) ? d : (d - half);
42
+ float theta_pair = 1.0f / std::pow(theta, (2.0f * pair) / Dh);
43
+ float angle = (float)(p0 + s) * theta_pair;
44
+ cos_h[s * Dh + d] = float_to_bf16(std::cos(angle));
45
+ sin_h[s * Dh + d] = float_to_bf16(std::sin(angle));
46
+ }
47
+ }
48
+ }
49
+
50
+ int main() {
51
+ const std::string model_dir = "/path/to/Qwen3-235B-A22B-Instruct-2507-BF16";
52
+ const std::string data_dir = "tests/attn_data";
53
+
54
+ ModelConfig cfg;
55
+ if (!cfg.load_from_json(model_dir + "/config.json")) return 1;
56
+ cfg.compute_derived(1, 0);
57
+ const int64_t D = cfg.hidden_size;
58
+ const int64_t Hq = cfg.num_attention_heads;
59
+ const int64_t Hkv = cfg.num_key_value_heads;
60
+ const int64_t Dh = cfg.head_dim;
61
+ const int64_t Q_DIM = Hq * Dh;
62
+ const int64_t KV_DIM = Hkv * Dh;
63
+ const double scale = 1.0 / std::sqrt((double)Dh);
64
+ const double eps = cfg.rms_norm_eps;
65
+ const float theta = cfg.rope_theta;
66
+
67
+ SafetensorsLoader st;
68
+ if (!st.open(model_dir)) return 1;
69
+ AclRuntime rt;
70
+ rt.init(0);
71
+
72
+ DeviceWeightsLoader dw(st, cfg);
73
+ SharedWeights shared;
74
+ LayerAttnWeights attn;
75
+ printf("Loading weights...\n");
76
+ if (!dw.load_shared(shared)) return 1;
77
+ if (!dw.load_attention(0, attn)) return 1;
78
+
79
+ // ---- Load 5 prefill tokens + use token[5]=random as "6th" decoded token ----
80
+ auto tok_raw = read_file(data_dir + "/token_ids.bin");
81
+ int32_t S_prefill = *(int32_t*)tok_raw.data();
82
+ if (S_prefill < 5) { fprintf(stderr, "need >=5 tokens\n"); return 1; }
83
+ std::vector<int32_t> tokens(S_prefill);
84
+ std::memcpy(tokens.data(), tok_raw.data() + 4, S_prefill * 4);
85
+
86
+ // Build 6-token sequence (reuse first 5; pick a 6th token id — use token 0 as a simple choice)
87
+ const int64_t S6 = 6;
88
+ const int64_t S5 = 5;
89
+ std::vector<int32_t> tok6(S6);
90
+ for (int i = 0; i < S5; i++) tok6[i] = tokens[i];
91
+ tok6[5] = tokens[0]; // any token works for cross-consistency test
92
+ printf("tokens6=["); for (auto t : tok6) printf("%d,", t); printf("]\n");
93
+
94
+ // ---- Causal mask (2048x2048, sparse_mode=3) shared across both paths ----
95
+ const int64_t MASK = 2048;
96
+ DeviceBuffer mask_dev(MASK * MASK);
97
+ std::vector<uint8_t> mask_host(MASK * MASK, 0);
98
+ for (int i = 0; i < MASK; i++)
99
+ for (int j = i+1; j < MASK; j++)
100
+ mask_host[i*MASK + j] = 1;
101
+ ACL_CHECK(aclrtMemcpy(mask_dev.get(), MASK*MASK, mask_host.data(), MASK*MASK, ACL_MEMCPY_HOST_TO_DEVICE));
102
+ auto t_mask = make_contig_tensor(mask_dev.get(), ACL_BOOL, {1, 1, MASK, MASK});
103
+
104
+ // =========================================================================
105
+ // PATH A: 6-token prefill (reference)
106
+ // =========================================================================
107
+ printf("\n[Path A] 6-token prefill reference\n");
108
+
109
+ DeviceBuffer tokA_dev(S6 * 4);
110
+ ACL_CHECK(aclrtMemcpy(tokA_dev.get(), S6*4, tok6.data(), S6*4, ACL_MEMCPY_HOST_TO_DEVICE));
111
+ auto t_tokA = make_contig_tensor(tokA_dev.get(), ACL_INT32, {S6});
112
+ auto t_embed_w = make_contig_tensor(shared.embed_tokens.get(), ACL_BF16, {cfg.vocab_size, D});
113
+
114
+ DeviceBuffer xA_dev(S6 * D * 2);
115
+ auto t_xA = make_contig_tensor(xA_dev.get(), ACL_BF16, {S6, D});
116
+ index_select(rt.stream(), t_embed_w.get(), 0, t_tokA.get(), t_xA.get());
117
+ rt.sync();
118
+
119
+ DeviceBuffer xnA_dev(S6 * D * 2);
120
+ DeviceBuffer rstdA_dev(S6 * 4);
121
+ auto t_xnA = make_contig_tensor(xnA_dev.get(), ACL_BF16, {S6, D});
122
+ auto t_ln_w = make_contig_tensor(attn.input_layernorm.get(), ACL_BF16, {D});
123
+ auto t_rstdA = make_contig_tensor(rstdA_dev.get(), ACL_FLOAT, {S6});
124
+ rms_norm(rt.stream(), t_xA.get(), t_ln_w.get(), eps, t_xnA.get(), t_rstdA.get());
125
+
126
+ DeviceBuffer qA_dev(S6 * Q_DIM * 2);
127
+ DeviceBuffer kA_dev(S6 * KV_DIM * 2);
128
+ DeviceBuffer vA_dev(S6 * KV_DIM * 2);
129
+ auto t_qA = make_contig_tensor(qA_dev.get(), ACL_BF16, {S6, Q_DIM});
130
+ auto t_kA = make_contig_tensor(kA_dev.get(), ACL_BF16, {S6, KV_DIM});
131
+ auto t_vA = make_contig_tensor(vA_dev.get(), ACL_BF16, {S6, KV_DIM});
132
+ linear_hf(rt.stream(), t_xnA.get(), attn.q_proj.get(), ACL_BF16, Q_DIM, D, t_qA.get());
133
+ linear_hf(rt.stream(), t_xnA.get(), attn.k_proj.get(), ACL_BF16, KV_DIM, D, t_kA.get());
134
+ linear_hf(rt.stream(), t_xnA.get(), attn.v_proj.get(), ACL_BF16, KV_DIM, D, t_vA.get());
135
+
136
+ // Per-head norm
137
+ auto t_qA_4d = make_contig_tensor(qA_dev.get(), ACL_BF16, {1, S6, Hq, Dh});
138
+ auto t_kA_4d = make_contig_tensor(kA_dev.get(), ACL_BF16, {1, S6, Hkv, Dh});
139
+ auto t_qn_w = make_contig_tensor(attn.q_norm.get(), ACL_BF16, {Dh});
140
+ auto t_kn_w = make_contig_tensor(attn.k_norm.get(), ACL_BF16, {Dh});
141
+ DeviceBuffer rstd_qA(S6 * Hq * 4), rstd_kA(S6 * Hkv * 4);
142
+ auto t_rstd_qA = make_contig_tensor(rstd_qA.get(), ACL_FLOAT, {1, S6, Hq});
143
+ auto t_rstd_kA = make_contig_tensor(rstd_kA.get(), ACL_FLOAT, {1, S6, Hkv});
144
+ rms_norm(rt.stream(), t_qA_4d.get(), t_qn_w.get(), eps, t_qA_4d.get(), t_rstd_qA.get());
145
+ rms_norm(rt.stream(), t_kA_4d.get(), t_kn_w.get(), eps, t_kA_4d.get(), t_rstd_kA.get());
146
+
147
+ // RoPE for positions 0..5
148
+ std::vector<uint16_t> cosA_h, sinA_h;
149
+ fill_cos_sin(cosA_h, sinA_h, 0, S6, Dh, theta);
150
+ DeviceBuffer cosA_dev(S6 * Dh * 2), sinA_dev(S6 * Dh * 2);
151
+ ACL_CHECK(aclrtMemcpy(cosA_dev.get(), S6*Dh*2, cosA_h.data(), S6*Dh*2, ACL_MEMCPY_HOST_TO_DEVICE));
152
+ ACL_CHECK(aclrtMemcpy(sinA_dev.get(), S6*Dh*2, sinA_h.data(), S6*Dh*2, ACL_MEMCPY_HOST_TO_DEVICE));
153
+ DeviceBuffer ropeA_scratch(1 * S6 * Hq * Dh * 2);
154
+ apply_rope_manual(rt.stream(), qA_dev.get(), 1, S6, Hq, Dh, kA_dev.get(), Hkv,
155
+ cosA_dev.get(), sinA_dev.get(), ropeA_scratch.get());
156
+
157
+ auto t_qA_bsh = make_contig_tensor(qA_dev.get(), ACL_BF16, {1, S6, Q_DIM});
158
+ auto t_kA_bsh = make_contig_tensor(kA_dev.get(), ACL_BF16, {1, S6, KV_DIM});
159
+ auto t_vA_bsh = make_contig_tensor(vA_dev.get(), ACL_BF16, {1, S6, KV_DIM});
160
+
161
+ DeviceBuffer attnA_out(1 * S6 * Q_DIM * 2);
162
+ auto t_attnA_out = make_contig_tensor(attnA_out.get(), ACL_BF16, {1, S6, Q_DIM});
163
+ fused_infer_attention_score(
164
+ rt.stream(), t_qA_bsh.get(), t_kA_bsh.get(), t_vA_bsh.get(),
165
+ t_mask.get(), {S6}, {S6}, Hq, Hkv, scale, 3, t_attnA_out.get());
166
+ rt.sync();
167
+
168
+ // Extract attnA_out[pos=5] into [1, 1, Q_DIM] for comparison
169
+ std::vector<uint16_t> refA_host(Q_DIM);
170
+ ACL_CHECK(aclrtMemcpy(refA_host.data(), Q_DIM*2,
171
+ (char*)attnA_out.get() + 5 * Q_DIM * 2, Q_DIM*2,
172
+ ACL_MEMCPY_DEVICE_TO_HOST));
173
+ printf(" attnA_out[5, :4] = %.5f %.5f %.5f %.5f\n",
174
+ bf16_to_float(refA_host[0]), bf16_to_float(refA_host[1]),
175
+ bf16_to_float(refA_host[2]), bf16_to_float(refA_host[3]));
176
+
177
+ // =========================================================================
178
+ // PATH B: 5-token prefill + KV cache → 1-token decode
179
+ // =========================================================================
180
+ printf("\n[Path B] 5-prefill + 1-decode via KV cache\n");
181
+
182
+ const int64_t MAX_LEN = 128; // small cache for test
183
+ DeviceBuffer k_cache(MAX_LEN * KV_DIM * 2);
184
+ DeviceBuffer v_cache(MAX_LEN * KV_DIM * 2);
185
+ // Zero-init unused slots (not strictly needed, FIAS uses actual_seq_lens).
186
+
187
+ // ---- Prefill 5 tokens ----
188
+ DeviceBuffer tokB_dev(S5 * 4);
189
+ ACL_CHECK(aclrtMemcpy(tokB_dev.get(), S5*4, tok6.data(), S5*4, ACL_MEMCPY_HOST_TO_DEVICE));
190
+ auto t_tokB = make_contig_tensor(tokB_dev.get(), ACL_INT32, {S5});
191
+ DeviceBuffer xB_dev(S5 * D * 2);
192
+ auto t_xB = make_contig_tensor(xB_dev.get(), ACL_BF16, {S5, D});
193
+ index_select(rt.stream(), t_embed_w.get(), 0, t_tokB.get(), t_xB.get());
194
+ rt.sync();
195
+
196
+ DeviceBuffer xnB_dev(S5 * D * 2);
197
+ DeviceBuffer rstdB_dev(S5 * 4);
198
+ auto t_xnB = make_contig_tensor(xnB_dev.get(), ACL_BF16, {S5, D});
199
+ auto t_rstdB = make_contig_tensor(rstdB_dev.get(), ACL_FLOAT, {S5});
200
+ rms_norm(rt.stream(), t_xB.get(), t_ln_w.get(), eps, t_xnB.get(), t_rstdB.get());
201
+
202
+ DeviceBuffer qB_dev(S5 * Q_DIM * 2);
203
+ DeviceBuffer kB_dev(S5 * KV_DIM * 2);
204
+ DeviceBuffer vB_dev(S5 * KV_DIM * 2);
205
+ auto t_qB = make_contig_tensor(qB_dev.get(), ACL_BF16, {S5, Q_DIM});
206
+ auto t_kB = make_contig_tensor(kB_dev.get(), ACL_BF16, {S5, KV_DIM});
207
+ auto t_vB = make_contig_tensor(vB_dev.get(), ACL_BF16, {S5, KV_DIM});
208
+ linear_hf(rt.stream(), t_xnB.get(), attn.q_proj.get(), ACL_BF16, Q_DIM, D, t_qB.get());
209
+ linear_hf(rt.stream(), t_xnB.get(), attn.k_proj.get(), ACL_BF16, KV_DIM, D, t_kB.get());
210
+ linear_hf(rt.stream(), t_xnB.get(), attn.v_proj.get(), ACL_BF16, KV_DIM, D, t_vB.get());
211
+
212
+ auto t_qB_4d = make_contig_tensor(qB_dev.get(), ACL_BF16, {1, S5, Hq, Dh});
213
+ auto t_kB_4d = make_contig_tensor(kB_dev.get(), ACL_BF16, {1, S5, Hkv, Dh});
214
+ DeviceBuffer rstd_qB(S5 * Hq * 4), rstd_kB(S5 * Hkv * 4);
215
+ auto t_rstd_qB = make_contig_tensor(rstd_qB.get(), ACL_FLOAT, {1, S5, Hq});
216
+ auto t_rstd_kB = make_contig_tensor(rstd_kB.get(), ACL_FLOAT, {1, S5, Hkv});
217
+ rms_norm(rt.stream(), t_qB_4d.get(), t_qn_w.get(), eps, t_qB_4d.get(), t_rstd_qB.get());
218
+ rms_norm(rt.stream(), t_kB_4d.get(), t_kn_w.get(), eps, t_kB_4d.get(), t_rstd_kB.get());
219
+
220
+ std::vector<uint16_t> cosB_h, sinB_h;
221
+ fill_cos_sin(cosB_h, sinB_h, 0, S5, Dh, theta);
222
+ DeviceBuffer cosB_dev(S5 * Dh * 2), sinB_dev(S5 * Dh * 2);
223
+ ACL_CHECK(aclrtMemcpy(cosB_dev.get(), S5*Dh*2, cosB_h.data(), S5*Dh*2, ACL_MEMCPY_HOST_TO_DEVICE));
224
+ ACL_CHECK(aclrtMemcpy(sinB_dev.get(), S5*Dh*2, sinB_h.data(), S5*Dh*2, ACL_MEMCPY_HOST_TO_DEVICE));
225
+ DeviceBuffer ropeB_scratch(1 * S5 * Hq * Dh * 2);
226
+ apply_rope_manual(rt.stream(), qB_dev.get(), 1, S5, Hq, Dh, kB_dev.get(), Hkv,
227
+ cosB_dev.get(), sinB_dev.get(), ropeB_scratch.get());
228
+ rt.sync();
229
+
230
+ // Append K, V to cache at positions 0..4.
231
+ ACL_CHECK(aclrtMemcpy(k_cache.get(), S5 * KV_DIM * 2,
232
+ kB_dev.get(), S5 * KV_DIM * 2, ACL_MEMCPY_DEVICE_TO_DEVICE));
233
+ ACL_CHECK(aclrtMemcpy(v_cache.get(), S5 * KV_DIM * 2,
234
+ vB_dev.get(), S5 * KV_DIM * 2, ACL_MEMCPY_DEVICE_TO_DEVICE));
235
+ printf(" cached K/V at positions 0..%ld\n", S5 - 1);
236
+
237
+ // ---- Decode 1 token (position = 5) ----
238
+ DeviceBuffer tokD_dev(1 * 4);
239
+ int32_t tok_dec = tok6[5];
240
+ ACL_CHECK(aclrtMemcpy(tokD_dev.get(), 4, &tok_dec, 4, ACL_MEMCPY_HOST_TO_DEVICE));
241
+ auto t_tokD = make_contig_tensor(tokD_dev.get(), ACL_INT32, {1});
242
+ DeviceBuffer xD_dev(1 * D * 2);
243
+ auto t_xD = make_contig_tensor(xD_dev.get(), ACL_BF16, {1, D});
244
+ index_select(rt.stream(), t_embed_w.get(), 0, t_tokD.get(), t_xD.get());
245
+
246
+ DeviceBuffer xnD_dev(1 * D * 2), rstdD_dev(1 * 4);
247
+ auto t_xnD = make_contig_tensor(xnD_dev.get(), ACL_BF16, {1, D});
248
+ auto t_rstdD = make_contig_tensor(rstdD_dev.get(), ACL_FLOAT, {1});
249
+ rms_norm(rt.stream(), t_xD.get(), t_ln_w.get(), eps, t_xnD.get(), t_rstdD.get());
250
+
251
+ DeviceBuffer qD_dev(1 * Q_DIM * 2), kD_dev(1 * KV_DIM * 2), vD_dev(1 * KV_DIM * 2);
252
+ auto t_qD = make_contig_tensor(qD_dev.get(), ACL_BF16, {1, Q_DIM});
253
+ auto t_kD = make_contig_tensor(kD_dev.get(), ACL_BF16, {1, KV_DIM});
254
+ auto t_vD = make_contig_tensor(vD_dev.get(), ACL_BF16, {1, KV_DIM});
255
+ linear_hf(rt.stream(), t_xnD.get(), attn.q_proj.get(), ACL_BF16, Q_DIM, D, t_qD.get());
256
+ linear_hf(rt.stream(), t_xnD.get(), attn.k_proj.get(), ACL_BF16, KV_DIM, D, t_kD.get());
257
+ linear_hf(rt.stream(), t_xnD.get(), attn.v_proj.get(), ACL_BF16, KV_DIM, D, t_vD.get());
258
+
259
+ auto t_qD_4d = make_contig_tensor(qD_dev.get(), ACL_BF16, {1, 1, Hq, Dh});
260
+ auto t_kD_4d = make_contig_tensor(kD_dev.get(), ACL_BF16, {1, 1, Hkv, Dh});
261
+ DeviceBuffer rstd_qD(1 * Hq * 4), rstd_kD(1 * Hkv * 4);
262
+ auto t_rstd_qD = make_contig_tensor(rstd_qD.get(), ACL_FLOAT, {1, 1, Hq});
263
+ auto t_rstd_kD = make_contig_tensor(rstd_kD.get(), ACL_FLOAT, {1, 1, Hkv});
264
+ rms_norm(rt.stream(), t_qD_4d.get(), t_qn_w.get(), eps, t_qD_4d.get(), t_rstd_qD.get());
265
+ rms_norm(rt.stream(), t_kD_4d.get(), t_kn_w.get(), eps, t_kD_4d.get(), t_rstd_kD.get());
266
+
267
+ // RoPE for position 5 only
268
+ std::vector<uint16_t> cosD_h, sinD_h;
269
+ fill_cos_sin(cosD_h, sinD_h, /*p0=*/5, /*L=*/1, Dh, theta);
270
+ DeviceBuffer cosD_dev(1 * Dh * 2), sinD_dev(1 * Dh * 2);
271
+ ACL_CHECK(aclrtMemcpy(cosD_dev.get(), Dh*2, cosD_h.data(), Dh*2, ACL_MEMCPY_HOST_TO_DEVICE));
272
+ ACL_CHECK(aclrtMemcpy(sinD_dev.get(), Dh*2, sinD_h.data(), Dh*2, ACL_MEMCPY_HOST_TO_DEVICE));
273
+ DeviceBuffer ropeD_scratch(1 * 1 * Hq * Dh * 2);
274
+ apply_rope_manual(rt.stream(), qD_dev.get(), 1, 1, Hq, Dh, kD_dev.get(), Hkv,
275
+ cosD_dev.get(), sinD_dev.get(), ropeD_scratch.get());
276
+ rt.sync();
277
+
278
+ // Append K, V to cache at position 5.
279
+ ACL_CHECK(aclrtMemcpy((char*)k_cache.get() + S5 * KV_DIM * 2, KV_DIM * 2,
280
+ kD_dev.get(), KV_DIM * 2, ACL_MEMCPY_DEVICE_TO_DEVICE));
281
+ ACL_CHECK(aclrtMemcpy((char*)v_cache.get() + S5 * KV_DIM * 2, KV_DIM * 2,
282
+ vD_dev.get(), KV_DIM * 2, ACL_MEMCPY_DEVICE_TO_DEVICE));
283
+
284
+ // FIAS decode: q [1, 1, Q_DIM], k/v [1, 6, KV_DIM] from cache.
285
+ auto t_qD_bsh = make_contig_tensor(qD_dev.get(), ACL_BF16, {1, 1, Q_DIM});
286
+ auto t_kC_bsh = make_contig_tensor(k_cache.get(), ACL_BF16, {1, S6, KV_DIM});
287
+ auto t_vC_bsh = make_contig_tensor(v_cache.get(), ACL_BF16, {1, S6, KV_DIM});
288
+
289
+ DeviceBuffer attnD_out(1 * 1 * Q_DIM * 2);
290
+ auto t_attnD_out = make_contig_tensor(attnD_out.get(), ACL_BF16, {1, 1, Q_DIM});
291
+ // Decode: q has 1 token, k/v have 6 tokens. Use sparse_mode=0 with no mask — the single q
292
+ // at the end can attend to all cached positions; there's no causal constraint on it.
293
+ fused_infer_attention_score(
294
+ rt.stream(), t_qD_bsh.get(), t_kC_bsh.get(), t_vC_bsh.get(),
295
+ nullptr, {1}, {S6},
296
+ Hq, Hkv, scale, 0, t_attnD_out.get());
297
+ rt.sync();
298
+
299
+ std::vector<uint16_t> decB_host(Q_DIM);
300
+ ACL_CHECK(aclrtMemcpy(decB_host.data(), Q_DIM*2, attnD_out.get(), Q_DIM*2, ACL_MEMCPY_DEVICE_TO_HOST));
301
+
302
+ // ---- Compare Path A vs Path B ----
303
+ printf("\n attnB_decode[:4] = %.5f %.5f %.5f %.5f\n",
304
+ bf16_to_float(decB_host[0]), bf16_to_float(decB_host[1]),
305
+ bf16_to_float(decB_host[2]), bf16_to_float(decB_host[3]));
306
+
307
+ double l2d = 0, l2r = 0, maxd = 0;
308
+ for (int i = 0; i < Q_DIM; i++) {
309
+ float a = bf16_to_float(decB_host[i]), b = bf16_to_float(refA_host[i]);
310
+ l2d += (a-b)*(a-b); l2r += b*b;
311
+ if (std::abs(a-b) > maxd) maxd = std::abs(a-b);
312
+ }
313
+ double rel = std::sqrt(l2d) / (std::sqrt(l2r) + 1e-10);
314
+ printf("\nDecode vs 6-prefill comparison: rel=%.4e max_abs=%.4f\n", rel, maxd);
315
+
316
+ bool pass = rel < 5e-2;
317
+ printf("\n%s\n", pass ? "=== test_attention_decode PASS ===" : "=== test_attention_decode FAIL ===");
318
+ return pass ? 0 : 1;
319
+ }
tests/test_attention_layer.cpp ADDED
@@ -0,0 +1,219 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // test_attention_layer.cpp — full single-layer attention forward (Qwen3-235B layer 0), TP=1.
2
+ // Validates C++ output against Python HF-style reference (attn_data/final_out.bin).
3
+ #include "acl_common.h"
4
+ #include "acl_runtime.h"
5
+ #include "aclnn_ops.h"
6
+ #include "device_weights.h"
7
+ #include "model_config.h"
8
+ #include "rope.h"
9
+ #include "safetensors_loader.h"
10
+
11
+ #include <cmath>
12
+ #include <cstdio>
13
+ #include <cstring>
14
+ #include <fstream>
15
+ #include <vector>
16
+
17
+ static float bf16_to_float(uint16_t x) {
18
+ uint32_t u = (uint32_t)x << 16; float f; std::memcpy(&f, &u, 4); return f;
19
+ }
20
+ static uint16_t float_to_bf16(float x) {
21
+ uint32_t u; std::memcpy(&u, &x, 4);
22
+ return (uint16_t)((u + 0x7FFF + ((u >> 16) & 1)) >> 16);
23
+ }
24
+
25
+ static std::vector<uint8_t> read_file(const std::string& p) {
26
+ std::ifstream f(p, std::ios::binary | std::ios::ate); size_t s = f.tellg();
27
+ f.seekg(0); std::vector<uint8_t> v(s); f.read((char*)v.data(), s); return v;
28
+ }
29
+
30
+ int main() {
31
+ const std::string model_dir = "/path/to/Qwen3-235B-A22B-Instruct-2507-BF16";
32
+ const std::string data_dir = "tests/attn_data";
33
+
34
+ ModelConfig cfg;
35
+ if (!cfg.load_from_json(model_dir + "/config.json")) return 1;
36
+ cfg.compute_derived(/*tp_size=*/1, /*tp_rank=*/0); // single rank for correctness test
37
+ const int64_t D = cfg.hidden_size;
38
+ const int64_t Hq = cfg.num_attention_heads;
39
+ const int64_t Hkv = cfg.num_key_value_heads;
40
+ const int64_t Dh = cfg.head_dim;
41
+ const int64_t Q_DIM = Hq * Dh;
42
+ const int64_t KV_DIM = Hkv * Dh;
43
+ const double scale = 1.0 / std::sqrt((double)Dh);
44
+ const double eps = cfg.rms_norm_eps;
45
+ const float theta = cfg.rope_theta;
46
+
47
+ SafetensorsLoader st;
48
+ if (!st.open(model_dir)) return 1;
49
+
50
+ AclRuntime rt;
51
+ rt.init(0);
52
+
53
+ // ---- Load weights (layer 0 attention + embed) ----
54
+ DeviceWeightsLoader dw(st, cfg);
55
+ SharedWeights shared;
56
+ LayerAttnWeights attn;
57
+ printf("Loading weights...\n");
58
+ if (!dw.load_shared(shared)) return 1;
59
+ if (!dw.load_attention(0, attn)) return 1;
60
+ printf(" shared.embed %.0fMB, attn total ~140MB\n", shared.embed_tokens.size / 1e6);
61
+
62
+ // ---- Load token ids (5 tokens: "The capital of France is") ----
63
+ auto tok_raw = read_file(data_dir + "/token_ids.bin");
64
+ int32_t S = *(int32_t*)tok_raw.data();
65
+ std::vector<int32_t> tokens(S);
66
+ std::memcpy(tokens.data(), tok_raw.data() + 4, S * 4);
67
+ printf("S=%d tokens=[", S); for (auto t : tokens) printf("%d,", t); printf("]\n");
68
+
69
+ // ---- Embed lookup: [S, D] ----
70
+ DeviceBuffer tok_dev(S * 4);
71
+ ACL_CHECK(aclrtMemcpy(tok_dev.get(), S * 4, tokens.data(), S * 4, ACL_MEMCPY_HOST_TO_DEVICE));
72
+ auto t_tok = make_contig_tensor(tok_dev.get(), ACL_INT32, {S});
73
+
74
+ // embed weight shape [vocab, D]
75
+ auto t_embed_w = make_contig_tensor(shared.embed_tokens.get(), ACL_BF16, {cfg.vocab_size, D});
76
+
77
+ DeviceBuffer x_dev(S * D * 2);
78
+ auto t_x = make_contig_tensor(x_dev.get(), ACL_BF16, {S, D});
79
+ index_select(rt.stream(), t_embed_w.get(), 0, t_tok.get(), t_x.get());
80
+ rt.sync();
81
+
82
+ // ---- Residual snapshot (copy x) ----
83
+ DeviceBuffer residual_dev(S * D * 2);
84
+ ACL_CHECK(aclrtMemcpyAsync(residual_dev.get(), S*D*2, x_dev.get(), S*D*2, ACL_MEMCPY_DEVICE_TO_DEVICE, rt.stream()));
85
+
86
+ // ---- Input layernorm ----
87
+ DeviceBuffer xn_dev(S * D * 2);
88
+ DeviceBuffer rstd_dev(S * 4);
89
+ auto t_xn = make_contig_tensor(xn_dev.get(), ACL_BF16, {S, D});
90
+ auto t_ln_w = make_contig_tensor(attn.input_layernorm.get(), ACL_BF16, {D});
91
+ auto t_rstd = make_contig_tensor(rstd_dev.get(), ACL_FLOAT, {S});
92
+ rms_norm(rt.stream(), t_x.get(), t_ln_w.get(), eps, t_xn.get(), t_rstd.get());
93
+
94
+ // ---- Q/K/V projections (linear_hf: y = x @ W.T, W stored as [out, in]) ----
95
+ DeviceBuffer q_dev(S * Q_DIM * 2);
96
+ DeviceBuffer k_dev(S * KV_DIM * 2);
97
+ DeviceBuffer v_dev(S * KV_DIM * 2);
98
+ auto t_q = make_contig_tensor(q_dev.get(), ACL_BF16, {S, Q_DIM});
99
+ auto t_k = make_contig_tensor(k_dev.get(), ACL_BF16, {S, KV_DIM});
100
+ auto t_v = make_contig_tensor(v_dev.get(), ACL_BF16, {S, KV_DIM});
101
+ linear_hf(rt.stream(), t_xn.get(), attn.q_proj.get(), ACL_BF16, Q_DIM, D, t_q.get());
102
+ linear_hf(rt.stream(), t_xn.get(), attn.k_proj.get(), ACL_BF16, KV_DIM, D, t_k.get());
103
+ linear_hf(rt.stream(), t_xn.get(), attn.v_proj.get(), ACL_BF16, KV_DIM, D, t_v.get());
104
+
105
+ // ---- Reshape Q, K as [B=1, S, N, Dh] for q_norm/k_norm + RoPE ----
106
+ // Same memory; just new views.
107
+ // q_dev has S * Q_DIM = S * Hq * Dh BF16
108
+ auto t_q_4d = make_contig_tensor(q_dev.get(), ACL_BF16, {1, S, Hq, Dh});
109
+ auto t_k_4d = make_contig_tensor(k_dev.get(), ACL_BF16, {1, S, Hkv, Dh});
110
+
111
+ // Per-head RmsNorm on last dim (gamma shape [Dh])
112
+ auto t_qn_w = make_contig_tensor(attn.q_norm.get(), ACL_BF16, {Dh});
113
+ auto t_kn_w = make_contig_tensor(attn.k_norm.get(), ACL_BF16, {Dh});
114
+ DeviceBuffer rstd_q_dev(S * Hq * 4); // rstd shape = q's all-but-last dims
115
+ DeviceBuffer rstd_k_dev(S * Hkv * 4);
116
+ auto t_rstd_q = make_contig_tensor(rstd_q_dev.get(), ACL_FLOAT, {1, S, Hq});
117
+ auto t_rstd_k = make_contig_tensor(rstd_k_dev.get(), ACL_FLOAT, {1, S, Hkv});
118
+ // RmsNorm in place on q/k
119
+ rms_norm(rt.stream(), t_q_4d.get(), t_qn_w.get(), eps, t_q_4d.get(), t_rstd_q.get());
120
+ rms_norm(rt.stream(), t_k_4d.get(), t_kn_w.get(), eps, t_k_4d.get(), t_rstd_k.get());
121
+
122
+ // ---- Compute cos/sin on device ----
123
+ // cos/sin shape [1, S, Dh] BF16
124
+ std::vector<uint16_t> cos_host(S * Dh), sin_host(S * Dh);
125
+ for (int s = 0; s < S; s++) {
126
+ for (int64_t d = 0; d < Dh; d++) {
127
+ // freq index: for half-half layout, index d corresponds to pair index (d % (Dh/2))
128
+ int64_t half = Dh / 2;
129
+ int64_t pair = (d < half) ? d : (d - half);
130
+ float theta_pair = 1.0f / std::pow(theta, (2.0f * pair) / Dh);
131
+ float angle = (float)s * theta_pair;
132
+ cos_host[s * Dh + d] = float_to_bf16(std::cos(angle));
133
+ sin_host[s * Dh + d] = float_to_bf16(std::sin(angle));
134
+ }
135
+ }
136
+ DeviceBuffer cos_dev(S * Dh * 2);
137
+ DeviceBuffer sin_dev(S * Dh * 2);
138
+ ACL_CHECK(aclrtMemcpy(cos_dev.get(), S*Dh*2, cos_host.data(), S*Dh*2, ACL_MEMCPY_HOST_TO_DEVICE));
139
+ ACL_CHECK(aclrtMemcpy(sin_dev.get(), S*Dh*2, sin_host.data(), S*Dh*2, ACL_MEMCPY_HOST_TO_DEVICE));
140
+
141
+ // ---- RoPE ----
142
+ DeviceBuffer rope_scratch(1 * S * Hq * Dh * 2);
143
+ apply_rope_manual(rt.stream(),
144
+ q_dev.get(), 1, S, Hq, Dh,
145
+ k_dev.get(), Hkv,
146
+ cos_dev.get(), sin_dev.get(),
147
+ rope_scratch.get());
148
+
149
+ // ---- FIAS ----
150
+ // q/k/v are reshaped back to BSH [1, S, Hq*Dh or Hkv*Dh]
151
+ auto t_q_bsh = make_contig_tensor(q_dev.get(), ACL_BF16, {1, S, Q_DIM});
152
+ auto t_k_bsh = make_contig_tensor(k_dev.get(), ACL_BF16, {1, S, KV_DIM});
153
+ auto t_v_bsh = make_contig_tensor(v_dev.get(), ACL_BF16, {1, S, KV_DIM});
154
+
155
+ // Causal mask 2048x2048 (sparse_mode=3 requires fixed size)
156
+ const int64_t MASK = 2048;
157
+ DeviceBuffer mask_dev(MASK * MASK); // bool = 1 byte
158
+ std::vector<uint8_t> mask_host(MASK * MASK, 0);
159
+ for (int i = 0; i < MASK; i++)
160
+ for (int j = i+1; j < MASK; j++)
161
+ mask_host[i*MASK + j] = 1; // upper triangular = True
162
+ ACL_CHECK(aclrtMemcpy(mask_dev.get(), MASK*MASK, mask_host.data(), MASK*MASK, ACL_MEMCPY_HOST_TO_DEVICE));
163
+ auto t_mask = make_contig_tensor(mask_dev.get(), ACL_BOOL, {1, 1, MASK, MASK});
164
+
165
+ DeviceBuffer attn_out_dev(1 * S * Q_DIM * 2);
166
+ auto t_attn_out = make_contig_tensor(attn_out_dev.get(), ACL_BF16, {1, S, Q_DIM});
167
+
168
+ fused_infer_attention_score(
169
+ rt.stream(),
170
+ t_q_bsh.get(), t_k_bsh.get(), t_v_bsh.get(),
171
+ t_mask.get(),
172
+ {S}, {S},
173
+ Hq, Hkv,
174
+ scale,
175
+ 3, // sparse_mode = causal
176
+ t_attn_out.get());
177
+
178
+ // ---- O projection ----
179
+ auto t_attn_out_2d = make_contig_tensor(attn_out_dev.get(), ACL_BF16, {S, Q_DIM});
180
+ DeviceBuffer o_dev(S * D * 2);
181
+ auto t_o = make_contig_tensor(o_dev.get(), ACL_BF16, {S, D});
182
+ linear_hf(rt.stream(), t_attn_out_2d.get(), attn.o_proj.get(), ACL_BF16, D, Q_DIM, t_o.get());
183
+
184
+ // ---- Residual add: out = residual + o ----
185
+ auto t_res = make_contig_tensor(residual_dev.get(), ACL_BF16, {S, D});
186
+ float alpha_v = 1.0f;
187
+ aclScalar* alpha = aclCreateScalar(&alpha_v, ACL_FLOAT);
188
+ DeviceBuffer out_dev(S * D * 2);
189
+ auto t_out = make_contig_tensor(out_dev.get(), ACL_BF16, {S, D});
190
+ {
191
+ uint64_t ws = 0; aclOpExecutor* e = nullptr;
192
+ ACLNN_CHECK(aclnnAddGetWorkspaceSize(t_res.get(), t_o.get(), alpha, t_out.get(), &ws, &e));
193
+ DeviceBuffer wb; if (ws > 0) wb.alloc(ws);
194
+ ACLNN_CHECK(aclnnAdd(wb.get(), ws, e, rt.stream()));
195
+ }
196
+ aclDestroyScalar(alpha);
197
+ rt.sync();
198
+
199
+ // ---- Compare with Python reference ----
200
+ auto ref_h = read_file(data_dir + "/final_out.bin");
201
+ std::vector<uint16_t> cxx(S * D);
202
+ ACL_CHECK(aclrtMemcpy(cxx.data(), S*D*2, out_dev.get(), S*D*2, ACL_MEMCPY_DEVICE_TO_HOST));
203
+ auto* ref = (const uint16_t*)ref_h.data();
204
+
205
+ double l2d = 0, l2r = 0, maxd = 0;
206
+ for (int i = 0; i < S * D; i++) {
207
+ float a = bf16_to_float(cxx[i]), b = bf16_to_float(ref[i]);
208
+ l2d += (a-b)*(a-b); l2r += b*b;
209
+ if (std::abs(a-b) > maxd) maxd = std::abs(a-b);
210
+ }
211
+ double rel = std::sqrt(l2d) / (std::sqrt(l2r) + 1e-10);
212
+ printf("\nAttention layer output compare: rel=%.4e max_abs=%.4f\n", rel, maxd);
213
+ printf(" cxx[0, :4]: "); for (int i = 0; i < 4; i++) printf("%.6f ", bf16_to_float(cxx[i]));
214
+ printf("\n ref[0, :4]: "); for (int i = 0; i < 4; i++) printf("%.6f ", bf16_to_float(ref[i])); printf("\n");
215
+
216
+ bool pass = rel < 5e-2; // BF16 accumulation across 5+ ops loses ~1-2% per step
217
+ printf("\n%s\n", pass ? "=== test_attention_layer PASS ===" : "=== test_attention_layer FAIL ===");
218
+ return pass ? 0 : 1;
219
+ }
tests/test_batch_correctness.cpp ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // test_batch_correctness.cpp — verify that forward with S>1 at past_len>0 produces the
2
+ // same logits at each position as sequential S=1 decodes.
3
+ //
4
+ // This is the foundation for speculative decoding / PLD: the main model must predict logits
5
+ // for each of K candidate positions in one batched forward pass matching sequential behavior.
6
+ #include "runner.h"
7
+
8
+ #include <cstdio>
9
+ #include <cstring>
10
+ #include <vector>
11
+ #include <cmath>
12
+
13
+ static float bf16_to_float(uint16_t x) {
14
+ uint32_t u = (uint32_t)x << 16; float f; std::memcpy(&f, &u, 4); return f;
15
+ }
16
+
17
+ int main() {
18
+ const std::string model_dir = "/path/to/Qwen3-235B-A22B-Instruct-2507-BF16";
19
+ int tp_rank = 0, tp_size = 1;
20
+ if (const char* v = std::getenv("TP_RANK")) tp_rank = std::atoi(v);
21
+ if (const char* v = std::getenv("TP_SIZE")) tp_size = std::atoi(v);
22
+ bool is_master = tp_rank == 0;
23
+
24
+ Runner r;
25
+ if (!r.init(model_dir, tp_size, tp_rank, 94, 512)) return 1;
26
+ const int64_t V = r.cfg().vocab_size;
27
+
28
+ // Prefix
29
+ std::vector<int32_t> prompt = {785, 6722, 315, 9625, 374};
30
+ DeviceBuffer logits0;
31
+ r.prefill(prompt.data(), prompt.size(), logits0);
32
+ std::vector<uint16_t> h_last0(V);
33
+ if (is_master) ACL_CHECK(aclrtMemcpy(h_last0.data(), V*2, logits0.get(), V*2, ACL_MEMCPY_DEVICE_TO_HOST));
34
+ int next0 = 0;
35
+ if (is_master) {
36
+ float best = -1e30; for (int i = 0; i < V; i++) { float v = bf16_to_float(h_last0[i]); if (v > best) { best = v; next0 = i; } }
37
+ }
38
+ // Broadcast next0 (simple: let rank 0 decide and non-master ranks independently too)
39
+ int32_t token_seq[4];
40
+ if (is_master) token_seq[0] = next0;
41
+
42
+ // --- Path A: sequential S=1 decode × 4 times ---
43
+ std::vector<std::vector<uint16_t>> seq_logits(4);
44
+ for (int i = 0; i < 4; i++) seq_logits[i].resize(V);
45
+
46
+ // first decode: takes prompt's last logit argmax
47
+ // Here we need identical approach on all ranks. Use random token id for consistency.
48
+ std::vector<int32_t> seq_tokens = {next0, 100, 200, 300}; // deterministic for test
49
+
50
+ for (int i = 0; i < 4; i++) {
51
+ DeviceBuffer out;
52
+ r.decode(seq_tokens[i], out);
53
+ if (is_master) ACL_CHECK(aclrtMemcpy(seq_logits[i].data(), V*2, out.get(), V*2, ACL_MEMCPY_DEVICE_TO_HOST));
54
+ }
55
+ int64_t past_after_seq = r.past_len();
56
+
57
+ // --- Path B: reset, re-prefill, then ONE batch forward with S=4 ---
58
+ r.reset_cache();
59
+ DeviceBuffer logits_reprefill;
60
+ r.prefill(prompt.data(), prompt.size(), logits_reprefill);
61
+
62
+ DeviceBuffer batch_logits;
63
+ r.prefill(seq_tokens.data(), 4, batch_logits);
64
+ // prefill returns logits for LAST position only (S=4 gives [1, V], not [4, V]).
65
+ // Hmm — that's a limitation. To do PLD we need logits for all 4 positions.
66
+ // For now, just compare the LAST one (position 4 after prefix).
67
+
68
+ std::vector<uint16_t> batch_last(V);
69
+ if (is_master) ACL_CHECK(aclrtMemcpy(batch_last.data(), V*2, batch_logits.get(), V*2, ACL_MEMCPY_DEVICE_TO_HOST));
70
+
71
+ if (is_master) {
72
+ printf("\n=== Batched vs Sequential Decode Correctness ===\n");
73
+ double l2d=0, l2r=0, maxd=0;
74
+ for (int i = 0; i < V; i++) {
75
+ float a = bf16_to_float(batch_last[i]), b = bf16_to_float(seq_logits[3][i]);
76
+ l2d += (a-b)*(a-b); l2r += b*b;
77
+ if (std::abs(a-b) > maxd) maxd = std::abs(a-b);
78
+ }
79
+ double rel = std::sqrt(l2d) / (std::sqrt(l2r) + 1e-10);
80
+ printf("Last-position logits:\n");
81
+ printf(" seq[3] argmax = "); {
82
+ int b = 0; float bv = bf16_to_float(seq_logits[3][0]);
83
+ for (int i = 1; i < V; i++) if (bf16_to_float(seq_logits[3][i]) > bv) { bv = bf16_to_float(seq_logits[3][i]); b = i; }
84
+ printf("%d (%.3f)\n", b, bv);
85
+ }
86
+ printf(" batch argmax = "); {
87
+ int b = 0; float bv = bf16_to_float(batch_last[0]);
88
+ for (int i = 1; i < V; i++) if (bf16_to_float(batch_last[i]) > bv) { bv = bf16_to_float(batch_last[i]); b = i; }
89
+ printf("%d (%.3f)\n", b, bv);
90
+ }
91
+ printf(" rel=%.4e max=%.4f\n", rel, maxd);
92
+ printf(" %s\n", rel < 5e-2 ? "PASS" : "FAIL (batch forward diverges from sequential)");
93
+ printf("\nNote: current Runner.prefill() returns ONLY last-position logits. For PLD\n");
94
+ printf("we need all-position logits: requires extending prefill to optionally output\n");
95
+ printf("[S, V] logits tensor.\n");
96
+ }
97
+ return 0;
98
+ }
tests/test_batch_decode.cpp ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // test_batch_decode.cpp — benchmark decode with different batch sizes S = 1, 2, 4, 8.
2
+ //
3
+ // Purpose: quantify the cost of "batched decode" (a.k.a. the ingredient speculative decoding
4
+ // relies on). If Runner.prefill(S=K) forward-pass is only a small overhead over S=1, then
5
+ // spec-decoding with K draft tokens gives ~K× speedup at high accept rate.
6
+ //
7
+ // Per-token amortized cost:
8
+ // cost(S) / S
9
+ // Speculative decoding benefit:
10
+ // expected_accept_rate * K = effective tokens per forward
11
+ // → TG = expected / (cost(S=K+1) / 1_sec)
12
+ #include "runner.h"
13
+
14
+ #include <chrono>
15
+ #include <cstdio>
16
+ #include <cstring>
17
+ #include <vector>
18
+
19
+ int main() {
20
+ const std::string model_dir = "/path/to/Qwen3-235B-A22B-Instruct-2507-BF16";
21
+ Runner r;
22
+ int tp_rank = 0, tp_size = 1;
23
+ if (const char* v = std::getenv("TP_RANK")) tp_rank = std::atoi(v);
24
+ if (const char* v = std::getenv("TP_SIZE")) tp_size = std::atoi(v);
25
+ bool is_master = tp_rank == 0;
26
+
27
+ if (!r.init(model_dir, tp_size, tp_rank, /*num_layers=*/94, /*max_seq=*/512)) return 1;
28
+
29
+ // Prefill a short context so decode has some KV cache
30
+ std::vector<int32_t> prompt = {785, 6722, 315, 9625, 374}; // "The capital of France is"
31
+ DeviceBuffer logits;
32
+ r.prefill(prompt.data(), prompt.size(), logits);
33
+
34
+ auto now = []() { return std::chrono::steady_clock::now(); };
35
+ auto ms = [](auto t0, auto t1) { return std::chrono::duration<double, std::milli>(t1 - t0).count(); };
36
+
37
+ std::vector<int> batch_sizes = {1, 2, 4, 8};
38
+ int N_ITERS = 20;
39
+
40
+ if (is_master) {
41
+ printf("\n=== Batched decode forward benchmark (94 layers, TP=%d) ===\n", tp_size);
42
+ printf("Each row: forward with S=K new tokens after prefill\n");
43
+ printf("%-5s %-12s %-18s %-18s %s\n",
44
+ "S", "ms/forward", "ms/token (amort)", "tokens/sec", "vs S=1 efficiency");
45
+ }
46
+
47
+ double base_per_token = 0;
48
+ for (int S : batch_sizes) {
49
+ // Reset cache between measurements to keep cache size fair (same position for each)
50
+ // Actually we want to simulate: after some past_len, do 1 forward with S new tokens.
51
+ // Use prefill() which accepts S>=1.
52
+
53
+ std::vector<double> times;
54
+ for (int iter = 0; iter < N_ITERS + 3; iter++) { // +3 for warmup
55
+ r.reset_cache();
56
+ r.prefill(prompt.data(), prompt.size(), logits); // re-prefill
57
+
58
+ // New forward with S tokens (as if doing speculative verify)
59
+ std::vector<int32_t> new_tokens(S, 100); // dummy token ids
60
+ auto t0 = now();
61
+ DeviceBuffer logits2;
62
+ r.prefill(new_tokens.data(), S, logits2);
63
+ auto t1 = now();
64
+ if (iter >= 3) times.push_back(ms(t0, t1));
65
+ }
66
+ std::sort(times.begin(), times.end());
67
+ double median_ms = times[times.size() / 2];
68
+ double per_token = median_ms / S;
69
+ double tok_per_sec = 1000.0 / per_token;
70
+ if (S == 1) base_per_token = per_token;
71
+ double efficiency = base_per_token / per_token * 100.0;
72
+
73
+ if (is_master) {
74
+ printf("%-5d %-12.2f %-18.2f %-18.2f %.1f%%\n",
75
+ S, median_ms, per_token, tok_per_sec, efficiency);
76
+ }
77
+ }
78
+
79
+ if (is_master) {
80
+ printf("\n=== Interpretation ===\n");
81
+ printf("If S=4 forward ~ S=1 (efficiency high), spec decoding with accept_rate=70%%\n");
82
+ printf("gives TG = 0.7*4 / cost(S=5) vs baseline 1 / cost(S=1) = up to 2.8× speedup.\n");
83
+ }
84
+ return 0;
85
+ }
tests/test_chat_flow.sh ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+ # test_chat_flow.sh — end-to-end integration smoke test for the CLI.
3
+ #
4
+ # Exercises:
5
+ # - --prompt-file
6
+ # - Multi-turn --chat memory (remembers Alice's name in turn 2)
7
+ # - --reset command in REPL
8
+ # - --system prompt
9
+ # - EOS detection at <|im_end|>
10
+ #
11
+ # Requires TP=16 Ascend 910 setup. Run from the repo root.
12
+ #
13
+ # Exit: 0 on all-pass, nonzero with reason.
14
+ set -u
15
+ BIN="./build/qwen3-moe-aclnn"
16
+ MODEL="${MODEL_DIR:-/path/to/Qwen3-235B-A22B-Instruct-2507-BF16}"
17
+ LAUNCH="./scripts/tp_launch.sh"
18
+ TP="${TP_SIZE:-16}"
19
+ VOCAB="tokenizer_data/vocab.bin"
20
+
21
+ [ -x "$BIN" ] || { echo "FAIL: $BIN not built"; exit 1; }
22
+ [ -x "$LAUNCH" ] || { echo "FAIL: $LAUNCH not found"; exit 1; }
23
+
24
+ pass=0; fail=0
25
+ check() {
26
+ local name="$1"; shift
27
+ local out="$1"; shift
28
+ local needle="$1"; shift
29
+ if echo "$out" | grep -qiF "$needle"; then
30
+ echo " [PASS] $name (found: '$needle')"; pass=$((pass+1))
31
+ else
32
+ echo " [FAIL] $name (did NOT find: '$needle')"; fail=$((fail+1))
33
+ echo " ---- output ----"; echo "$out" | tail -20; echo " ---- end ----"
34
+ fi
35
+ }
36
+
37
+ echo "===== Test 1: --prompt-file + EOS ====="
38
+ echo "What is the capital of Japan?" > /tmp/chat_test_prompt.txt
39
+ OUT=$(${LAUNCH} ${TP} ${BIN} --model-dir "$MODEL" \
40
+ --prompt-file /tmp/chat_test_prompt.txt \
41
+ --chat --n-predict 50 --temperature 0 --vocab "$VOCAB" 2>&1)
42
+ check "prompt-file loaded" "$OUT" "capital of Japan"
43
+ check "answer mentions Tokyo" "$OUT" "Tokyo"
44
+ check "hit EOS" "$OUT" "hit EOS"
45
+
46
+ echo ""
47
+ echo "===== Test 2: multi-turn memory (remembers name) ====="
48
+ OUT=$(printf "My name is Alice.\nWhat is my name?\nquit\n" | \
49
+ ${LAUNCH} ${TP} ${BIN} --model-dir "$MODEL" \
50
+ --interactive --chat \
51
+ --system "You are a concise assistant. Answer in one short sentence." \
52
+ --temperature 0 --n-predict 40 --max-seq 512 \
53
+ --vocab "$VOCAB" 2>&1)
54
+ check "recalls Alice" "$OUT" "Alice"
55
+ check "has 2 turns" "$OUT" "past_len="
56
+
57
+ echo ""
58
+ echo "===== Test 3: reset command clears memory ====="
59
+ OUT=$(printf "My name is Bob.\nreset\nWhat is my name?\nquit\n" | \
60
+ ${LAUNCH} ${TP} ${BIN} --model-dir "$MODEL" \
61
+ --interactive --chat \
62
+ --system "Answer truthfully in one sentence." \
63
+ --temperature 0 --n-predict 40 --max-seq 512 \
64
+ --vocab "$VOCAB" 2>&1)
65
+ check "reset acknowledged" "$OUT" "cache + conversation reset"
66
+ # After reset, model should NOT know the name is Bob (probably says "don't know" or asks)
67
+ # We can't reliably check negation, so just check that the reset ran and turn 3 produced output
68
+ check "turn 3 ran" "$OUT" "bye"
69
+
70
+ echo ""
71
+ echo "===== Summary: $pass passed, $fail failed ====="
72
+ exit $fail
tests/test_engine_smoke.cpp ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ // test_engine_smoke.cpp — just verify engine.h compiles and links.
2
+ #include "engine.h"
3
+
4
+ int main() {
5
+ // No-op — all engine functions take many parameters and need real runtime. This test
6
+ // only validates that the header compiles and the core lib links.
7
+ return 0;
8
+ }
tests/test_layer_forward.cpp ADDED
@@ -0,0 +1,192 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // test_layer_forward.cpp — integration test for one full transformer layer via engine.h.
2
+ //
3
+ // Chain: embed_5_tokens → attention_forward (prefill, past=0) → +residual → moe_forward → +residual
4
+ // Expected: final output matches moe_data/final_out.bin within BF16 precision (rel < 5e-2).
5
+ #include "acl_common.h"
6
+ #include "acl_runtime.h"
7
+ #include "aclnn_ops.h"
8
+ #include "device_weights.h"
9
+ #include "engine.h"
10
+ #include "model_config.h"
11
+ #include "safetensors_loader.h"
12
+
13
+ #include <cmath>
14
+ #include <cstdio>
15
+ #include <cstring>
16
+ #include <fstream>
17
+ #include <vector>
18
+
19
+ static float bf16_to_float(uint16_t x) {
20
+ uint32_t u = (uint32_t)x << 16; float f; std::memcpy(&f, &u, 4); return f;
21
+ }
22
+ static std::vector<uint8_t> read_file(const std::string& p) {
23
+ std::ifstream f(p, std::ios::binary | std::ios::ate); size_t s = f.tellg();
24
+ f.seekg(0); std::vector<uint8_t> v(s); f.read((char*)v.data(), s); return v;
25
+ }
26
+
27
+ // Add: out = a + b (BF16).
28
+ static void bf16_add(aclrtStream stream, aclTensor* a, aclTensor* b, aclTensor* out) {
29
+ float alpha = 1.0f; aclScalar* al = aclCreateScalar(&alpha, ACL_FLOAT);
30
+ uint64_t ws = 0; aclOpExecutor* e = nullptr;
31
+ ACLNN_CHECK(aclnnAddGetWorkspaceSize(a, b, al, out, &ws, &e));
32
+ DeviceBuffer wb; if (ws > 0) wb.alloc(ws);
33
+ ACLNN_CHECK(aclnnAdd(wb.get(), ws, e, stream));
34
+ aclDestroyScalar(al);
35
+ }
36
+
37
+ int main() {
38
+ const std::string model_dir = "/path/to/Qwen3-235B-A22B-Instruct-2507-BF16";
39
+ const std::string attn_data = "tests/attn_data";
40
+ const std::string moe_data = "tests/moe_data";
41
+
42
+ ModelConfig cfg;
43
+ if (!cfg.load_from_json(model_dir + "/config.json")) return 1;
44
+ cfg.compute_derived(1, 0);
45
+ const int64_t D = cfg.hidden_size;
46
+ const int64_t Hq = cfg.n_heads_per_rank;
47
+ const int64_t Hkv = cfg.n_kv_heads_per_rank;
48
+ const int64_t Dh = cfg.head_dim;
49
+ const int64_t Q_DIM = Hq * Dh;
50
+ const int64_t KV_DIM = Hkv * Dh;
51
+ const int64_t I = cfg.i_per_rank;
52
+ const int64_t E = cfg.num_experts;
53
+ const int64_t K = cfg.num_experts_per_tok;
54
+ printf("Dims: D=%ld Q_DIM=%ld KV_DIM=%ld I=%ld E=%ld K=%ld\n", D, Q_DIM, KV_DIM, I, E, K);
55
+
56
+ SafetensorsLoader st;
57
+ if (!st.open(model_dir)) return 1;
58
+ AclRuntime rt;
59
+ rt.init(0);
60
+
61
+ DeviceWeightsLoader dw(st, cfg);
62
+ SharedWeights shared;
63
+ LayerAttnWeights attn;
64
+ LayerMoEWeights moe;
65
+ printf("Loading weights...\n");
66
+ if (!dw.load_shared(shared)) return 1;
67
+ if (!dw.load_attention(0, attn)) return 1;
68
+ if (!dw.load_moe(0, rt.stream(), moe)) return 1;
69
+ rt.sync();
70
+
71
+ // ---- Load 5 prefill tokens ----
72
+ auto tok_raw = read_file(attn_data + "/token_ids.bin");
73
+ int32_t S = *(int32_t*)tok_raw.data();
74
+ std::vector<int32_t> tokens(S);
75
+ std::memcpy(tokens.data(), tok_raw.data() + 4, S * 4);
76
+ printf("S=%d tokens=[", S); for (auto t : tokens) printf("%d,", t); printf("]\n");
77
+
78
+ // ---- Embed ----
79
+ DeviceBuffer tok_dev(S * 4);
80
+ ACL_CHECK(aclrtMemcpy(tok_dev.get(), S * 4, tokens.data(), S * 4, ACL_MEMCPY_HOST_TO_DEVICE));
81
+ auto t_tok = make_contig_tensor(tok_dev.get(), ACL_INT32, {S});
82
+ auto t_embed_w = make_contig_tensor(shared.embed_tokens.get(), ACL_BF16, {cfg.vocab_size, D});
83
+
84
+ DeviceBuffer x_dev(S * D * 2); // residual / input to layer
85
+ auto t_x = make_contig_tensor(x_dev.get(), ACL_BF16, {S, D});
86
+ index_select(rt.stream(), t_embed_w.get(), 0, t_tok.get(), t_x.get());
87
+ rt.sync();
88
+
89
+ // ---- Scratch buffers for attention_forward ----
90
+ const int64_t MAX_LEN = 128;
91
+ DeviceBuffer k_cache(MAX_LEN * KV_DIM * 2), v_cache(MAX_LEN * KV_DIM * 2);
92
+ DeviceBuffer q_sc(S * Q_DIM * 2), k_sc(S * KV_DIM * 2), v_sc(S * KV_DIM * 2);
93
+ DeviceBuffer xn_sc(S * D * 2), rstd_sc(S * std::max(Hq, Hkv) * 4);
94
+ DeviceBuffer rope_sc(1 * S * Hq * Dh * 2);
95
+ DeviceBuffer attn_fias_sc(S * Q_DIM * 2); // FIAS output buffer (before o_proj)
96
+ DeviceBuffer attn_out_dev(S * D * 2);
97
+
98
+ // ---- Causal mask (2048x2048) for prefill ----
99
+ const int64_t MASK = 2048;
100
+ DeviceBuffer mask_dev(MASK * MASK);
101
+ std::vector<uint8_t> mh(MASK * MASK, 0);
102
+ for (int i = 0; i < MASK; i++)
103
+ for (int j = i+1; j < MASK; j++) mh[i*MASK + j] = 1;
104
+ ACL_CHECK(aclrtMemcpy(mask_dev.get(), MASK*MASK, mh.data(), MASK*MASK, ACL_MEMCPY_HOST_TO_DEVICE));
105
+ auto t_mask = make_contig_tensor(mask_dev.get(), ACL_BOOL, {1, 1, MASK, MASK});
106
+
107
+ // ---- Attention forward ----
108
+ attention_forward(
109
+ rt.stream(), cfg, attn,
110
+ x_dev.get(), S,
111
+ /*past_len=*/0, k_cache.get(), v_cache.get(), MAX_LEN,
112
+ t_mask.get(),
113
+ q_sc.get(), k_sc.get(), v_sc.get(),
114
+ xn_sc.get(), rstd_sc.get(), rope_sc.get(),
115
+ attn_fias_sc.get(),
116
+ attn_out_dev.get());
117
+ rt.sync();
118
+
119
+ // ---- x1 = x + attn_out (residual) — should match attn_data/final_out.bin ----
120
+ DeviceBuffer x1_dev(S * D * 2);
121
+ auto t_attn_out = make_contig_tensor(attn_out_dev.get(), ACL_BF16, {S, D});
122
+ auto t_x1 = make_contig_tensor(x1_dev.get(), ACL_BF16, {S, D});
123
+ bf16_add(rt.stream(), t_x.get(), t_attn_out.get(), t_x1.get());
124
+ rt.sync();
125
+
126
+ auto attn_ref_h = read_file(attn_data + "/final_out.bin");
127
+ std::vector<uint16_t> x1_host(S * D);
128
+ ACL_CHECK(aclrtMemcpy(x1_host.data(), S*D*2, x1_dev.get(), S*D*2, ACL_MEMCPY_DEVICE_TO_HOST));
129
+ auto* ar = (const uint16_t*)attn_ref_h.data();
130
+ double al2d=0, al2r=0, amaxd=0;
131
+ for (int i = 0; i < S*D; i++) {
132
+ float a = bf16_to_float(x1_host[i]), b = bf16_to_float(ar[i]);
133
+ al2d += (a-b)*(a-b); al2r += b*b;
134
+ if (std::abs(a-b) > amaxd) amaxd = std::abs(a-b);
135
+ }
136
+ double arel = std::sqrt(al2d) / (std::sqrt(al2r) + 1e-10);
137
+ printf(" [attn] x + attn_out vs attn_data/final_out.bin: rel=%.4e max=%.4f\n", arel, amaxd);
138
+
139
+ // ---- MoE scratch buffers ----
140
+ const int64_t TOTAL = S * K;
141
+ DeviceBuffer moe_xn(S * D * 2), moe_rstd(S * 4);
142
+ DeviceBuffer moe_logits(S * E * 2);
143
+ DeviceBuffer moe_topk_w(S * K * 2), moe_topk_idx(S * K * 4), moe_row_idx(S * K * 4);
144
+ DeviceBuffer moe_ex_x(TOTAL * D * 2), moe_ex_ri(TOTAL * 4), moe_tpe(E * 8);
145
+ DeviceBuffer moe_fwd(TOTAL * 8);
146
+ DeviceBuffer moe_gate(TOTAL * I * 2), moe_up(TOTAL * I * 2), moe_down(TOTAL * D * 2);
147
+ DeviceBuffer moe_packed(TOTAL * D * 2), moe_weighted(S * K * D * 2);
148
+ DeviceBuffer moe_out_dev(S * D * 2);
149
+
150
+ moe_forward(rt.stream(), cfg, attn, moe,
151
+ x1_dev.get(), S,
152
+ moe_xn.get(), moe_rstd.get(),
153
+ moe_logits.get(),
154
+ moe_topk_w.get(), moe_topk_idx.get(), moe_row_idx.get(),
155
+ moe_ex_x.get(), moe_ex_ri.get(), moe_tpe.get(),
156
+ moe_fwd.get(),
157
+ moe_gate.get(), moe_up.get(), moe_down.get(),
158
+ moe_packed.get(), moe_weighted.get(),
159
+ moe_out_dev.get());
160
+ rt.sync();
161
+
162
+ // ---- x2 = x1 + moe_out (residual) — should match moe_data/final_out.bin ----
163
+ DeviceBuffer x2_dev(S * D * 2);
164
+ auto t_moe_out = make_contig_tensor(moe_out_dev.get(), ACL_BF16, {S, D});
165
+ auto t_x2 = make_contig_tensor(x2_dev.get(), ACL_BF16, {S, D});
166
+ bf16_add(rt.stream(), t_x1.get(), t_moe_out.get(), t_x2.get());
167
+ rt.sync();
168
+
169
+ auto moe_ref_h = read_file(moe_data + "/final_out.bin");
170
+ std::vector<uint16_t> x2_host(S * D);
171
+ ACL_CHECK(aclrtMemcpy(x2_host.data(), S*D*2, x2_dev.get(), S*D*2, ACL_MEMCPY_DEVICE_TO_HOST));
172
+ auto* mr = (const uint16_t*)moe_ref_h.data();
173
+ double ml2d=0, ml2r=0, mmaxd=0;
174
+ for (int i = 0; i < S*D; i++) {
175
+ float a = bf16_to_float(x2_host[i]), b = bf16_to_float(mr[i]);
176
+ ml2d += (a-b)*(a-b); ml2r += b*b;
177
+ if (std::abs(a-b) > mmaxd) mmaxd = std::abs(a-b);
178
+ }
179
+ double mrel = std::sqrt(ml2d) / (std::sqrt(ml2r) + 1e-10);
180
+ printf(" [full] x1 + moe_out vs moe_data/final_out.bin: rel=%.4e max=%.4f\n", mrel, mmaxd);
181
+ printf(" x2[0, :4]: %.5f %.5f %.5f %.5f\n",
182
+ bf16_to_float(x2_host[0]), bf16_to_float(x2_host[1]), bf16_to_float(x2_host[2]), bf16_to_float(x2_host[3]));
183
+ printf(" ref[0, :4]: %.5f %.5f %.5f %.5f\n",
184
+ bf16_to_float(mr[0]), bf16_to_float(mr[1]), bf16_to_float(mr[2]), bf16_to_float(mr[3]));
185
+
186
+ // Tolerance: attn chain 5e-3 (tight, only linear ops); full layer 1e-1 (MoE's discrete topk
187
+ // routing amplifies BF16 noise — tiny input changes flip expert selection, magnifying output
188
+ // delta. End-to-end CLI correctness is validated by test_chat_flow.sh separately.)
189
+ bool pass = (arel < 5e-3) && (mrel < 1e-1);
190
+ printf("\n%s\n", pass ? "=== test_layer_forward PASS ===" : "=== test_layer_forward FAIL ===");
191
+ return pass ? 0 : 1;
192
+ }
tests/test_linear_hf.cpp ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // test_linear_hf.cpp — verify linear_hf (y = x @ W.T with HF [out, in] layout).
2
+ #include "acl_common.h"
3
+ #include "acl_runtime.h"
4
+ #include "aclnn_ops.h"
5
+
6
+ #include <cmath>
7
+ #include <cstdio>
8
+ #include <cstring>
9
+ #include <fstream>
10
+ #include <vector>
11
+
12
+ static float bf16_to_float(uint16_t x) {
13
+ uint32_t u = (uint32_t)x << 16; float f; std::memcpy(&f, &u, 4); return f;
14
+ }
15
+
16
+ int main() {
17
+ const std::string data = "tests/mm_data";
18
+ int64_t N = 0, D = 0, OUT = 0;
19
+ {
20
+ std::ifstream f(data + "/shape.txt"); std::string line;
21
+ while (std::getline(f, line)) {
22
+ auto eq = line.find('='); if (eq == std::string::npos) continue;
23
+ auto k = line.substr(0, eq); auto v = std::atoll(line.c_str() + eq + 1);
24
+ if (k == "N") N = v; else if (k == "D") D = v; else if (k == "OUT") OUT = v;
25
+ }
26
+ }
27
+ printf("N=%ld D=%ld OUT=%ld\n", N, D, OUT);
28
+
29
+ AclRuntime rt;
30
+ rt.init(0);
31
+
32
+ auto read_all = [&](const std::string& p) {
33
+ std::ifstream f(p, std::ios::binary | std::ios::ate); size_t sz = f.tellg();
34
+ f.seekg(0); std::vector<uint8_t> v(sz); f.read((char*)v.data(), sz); return v;
35
+ };
36
+ auto x_h = read_all(data + "/x.bin");
37
+ auto W_h = read_all(data + "/W.bin");
38
+ auto yr_h = read_all(data + "/y_ref.bin");
39
+
40
+ DeviceBuffer x_d(N * D * 2);
41
+ DeviceBuffer W_d(OUT * D * 2);
42
+ DeviceBuffer y_d(N * OUT * 2);
43
+ ACL_CHECK(aclrtMemcpy(x_d.get(), x_h.size(), x_h.data(), x_h.size(), ACL_MEMCPY_HOST_TO_DEVICE));
44
+ ACL_CHECK(aclrtMemcpy(W_d.get(), W_h.size(), W_h.data(), W_h.size(), ACL_MEMCPY_HOST_TO_DEVICE));
45
+
46
+ auto t_x = make_contig_tensor(x_d.get(), ACL_BF16, {N, D});
47
+ auto t_y = make_contig_tensor(y_d.get(), ACL_BF16, {N, OUT});
48
+
49
+ linear_hf(rt.stream(), t_x.get(), W_d.get(), ACL_BF16, OUT, D, t_y.get());
50
+ rt.sync();
51
+
52
+ std::vector<uint16_t> y_cxx(N * OUT);
53
+ ACL_CHECK(aclrtMemcpy(y_cxx.data(), N * OUT * 2, y_d.get(), N * OUT * 2, ACL_MEMCPY_DEVICE_TO_HOST));
54
+ auto* y_ref = (const uint16_t*)yr_h.data();
55
+
56
+ double l2d = 0, l2r = 0, maxd = 0;
57
+ for (int i = 0; i < N * OUT; i++) {
58
+ float a = bf16_to_float(y_cxx[i]);
59
+ float b = bf16_to_float(y_ref[i]);
60
+ l2d += (a-b)*(a-b); l2r += b*b;
61
+ if (std::abs(a-b) > maxd) maxd = std::abs(a-b);
62
+ }
63
+ double rel = std::sqrt(l2d) / (std::sqrt(l2r) + 1e-10);
64
+ printf("L2 diff=%.4f ref=%.4f relative=%.4e max_abs=%.4f\n",
65
+ std::sqrt(l2d), std::sqrt(l2r), rel, maxd);
66
+ printf("y_cxx[0..3]: "); for (int i = 0; i < 4; i++) printf("%.3f ", bf16_to_float(y_cxx[i])); printf("\n");
67
+ printf("y_ref[0..3]: "); for (int i = 0; i < 4; i++) printf("%.3f ", bf16_to_float(y_ref[i])); printf("\n");
68
+
69
+ // BF16 matmul has more precision loss than RmsNorm. Allow 1% relative error.
70
+ bool ok = rel < 1e-2;
71
+ printf("\n%s\n", ok ? "=== test_linear_hf PASS ===" : "=== test_linear_hf FAIL ===");
72
+ return ok ? 0 : 1;
73
+ }
tests/test_model_config.cpp ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // test_model_config.cpp — load config.json, derive TP shard sizes, verify all expected
2
+ // HF tensors exist in safetensors for Qwen3-235B.
3
+ #include "model_config.h"
4
+ #include "safetensors_loader.h"
5
+
6
+ #include <cstdio>
7
+ #include <sstream>
8
+ #include <string>
9
+
10
+ int main(int argc, char** argv) {
11
+ std::string dir = argc > 1 ? argv[1]
12
+ : "/path/to/Qwen3-235B-A22B-Instruct-2507-BF16";
13
+ int tp_size = argc > 2 ? std::atoi(argv[2]) : 16;
14
+ int tp_rank = argc > 3 ? std::atoi(argv[3]) : 0;
15
+
16
+ ModelConfig cfg;
17
+ if (!cfg.load_from_json(dir + "/config.json")) return 1;
18
+ cfg.compute_derived(tp_size, tp_rank);
19
+ printf("%s\n", cfg.describe().c_str());
20
+
21
+ SafetensorsLoader loader;
22
+ if (!loader.open(dir)) return 1;
23
+
24
+ // Verify all expected tensor names & shapes match cfg.
25
+ int missing = 0, shape_mismatch = 0;
26
+ auto check_shape = [&](const std::string& name, const std::vector<int64_t>& expected) {
27
+ auto* m = loader.get(name);
28
+ if (!m) {
29
+ printf(" MISSING: %s\n", name.c_str());
30
+ missing++;
31
+ return;
32
+ }
33
+ if (m->shape != expected) {
34
+ printf(" SHAPE MISMATCH: %s got=[", name.c_str());
35
+ for (size_t i = 0; i < m->shape.size(); i++) printf("%s%ld", i ? "," : "", m->shape[i]);
36
+ printf("] want=[");
37
+ for (size_t i = 0; i < expected.size(); i++) printf("%s%ld", i ? "," : "", expected[i]);
38
+ printf("]\n");
39
+ shape_mismatch++;
40
+ }
41
+ };
42
+
43
+ // embed/head
44
+ check_shape("model.embed_tokens.weight", {cfg.vocab_size, cfg.hidden_size});
45
+ check_shape("lm_head.weight", {cfg.vocab_size, cfg.hidden_size});
46
+ check_shape("model.norm.weight", {cfg.hidden_size});
47
+
48
+ // Attention weights (HF stores as [out, in])
49
+ int64_t q_full = cfg.num_attention_heads * cfg.head_dim;
50
+ int64_t kv_full = cfg.num_key_value_heads * cfg.head_dim;
51
+ for (int L = 0; L < cfg.num_hidden_layers; L++) {
52
+ auto base = "model.layers." + std::to_string(L);
53
+ check_shape(base + ".input_layernorm.weight", {cfg.hidden_size});
54
+ check_shape(base + ".post_attention_layernorm.weight", {cfg.hidden_size});
55
+ check_shape(base + ".self_attn.q_proj.weight", {q_full, cfg.hidden_size});
56
+ check_shape(base + ".self_attn.k_proj.weight", {kv_full, cfg.hidden_size});
57
+ check_shape(base + ".self_attn.v_proj.weight", {kv_full, cfg.hidden_size});
58
+ check_shape(base + ".self_attn.o_proj.weight", {cfg.hidden_size, q_full});
59
+ // Qwen3 uses q_norm / k_norm (norm per head) — check existence
60
+ check_shape(base + ".self_attn.q_norm.weight", {cfg.head_dim});
61
+ check_shape(base + ".self_attn.k_norm.weight", {cfg.head_dim});
62
+ // MoE router
63
+ check_shape(base + ".mlp.gate.weight", {cfg.num_experts, cfg.hidden_size});
64
+ // Spot-check few experts (full enumeration is 94*384=36096 lines)
65
+ for (int e : {0, 1, 63, 127}) {
66
+ auto ebase = base + ".mlp.experts." + std::to_string(e);
67
+ check_shape(ebase + ".gate_proj.weight", {cfg.moe_intermediate_size, cfg.hidden_size});
68
+ check_shape(ebase + ".up_proj.weight", {cfg.moe_intermediate_size, cfg.hidden_size});
69
+ check_shape(ebase + ".down_proj.weight", {cfg.hidden_size, cfg.moe_intermediate_size});
70
+ }
71
+ }
72
+
73
+ // Print TP memory estimate
74
+ int64_t attn_bytes_per_rank = 0;
75
+ attn_bytes_per_rank += cfg.q_dim_per_rank * cfg.hidden_size * 2; // q_proj
76
+ attn_bytes_per_rank += cfg.kv_dim_per_rank * cfg.hidden_size * 2; // k_proj
77
+ attn_bytes_per_rank += cfg.kv_dim_per_rank * cfg.hidden_size * 2; // v_proj
78
+ attn_bytes_per_rank += cfg.hidden_size * cfg.q_dim_per_rank * 2; // o_proj
79
+ attn_bytes_per_rank *= cfg.num_hidden_layers;
80
+
81
+ int64_t moe_bytes_per_rank = 0;
82
+ // gate_exps + up_exps: [E, I_per_rank, D]
83
+ moe_bytes_per_rank += 2 * cfg.num_experts * cfg.i_per_rank * cfg.hidden_size * 2;
84
+ // down_exps: [E, D, I_per_rank]
85
+ moe_bytes_per_rank += cfg.num_experts * cfg.hidden_size * cfg.i_per_rank * 2;
86
+ moe_bytes_per_rank *= cfg.num_hidden_layers;
87
+
88
+ int64_t embed_bytes = cfg.vocab_size * cfg.hidden_size * 2 * 2; // embed + lm_head
89
+ int64_t router_bytes = cfg.num_experts * cfg.hidden_size * 2 * cfg.num_hidden_layers;
90
+ int64_t norm_bytes = cfg.hidden_size * 2 * (2 * cfg.num_hidden_layers + 1);
91
+ int64_t total_per_rank = attn_bytes_per_rank + moe_bytes_per_rank + embed_bytes + router_bytes + norm_bytes;
92
+
93
+ printf("\nPer-rank weight memory estimate (BF16, TP=%d):\n", tp_size);
94
+ printf(" attention: %.2f GB\n", attn_bytes_per_rank / 1e9);
95
+ printf(" MoE exps: %.2f GB\n", moe_bytes_per_rank / 1e9);
96
+ printf(" embed+head: %.2f GB (replicated)\n", embed_bytes / 1e9);
97
+ printf(" router: %.2f MB (replicated)\n", router_bytes / 1e6);
98
+ printf(" norms: %.2f MB (replicated)\n", norm_bytes / 1e6);
99
+ printf(" TOTAL: %.2f GB\n", total_per_rank / 1e9);
100
+
101
+ int errors = missing + shape_mismatch;
102
+ printf("\nMissing: %d, Shape mismatch: %d\n", missing, shape_mismatch);
103
+ printf("%s\n", errors == 0 ? "=== test_model_config PASS ==="
104
+ : "=== test_model_config FAIL ===");
105
+ return errors == 0 ? 0 : 1;
106
+ }
tests/test_moe_layer.cpp ADDED
@@ -0,0 +1,676 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // test_moe_layer.cpp — Full MoE layer forward (Qwen3-235B layer 0), TP=1.
2
+ //
3
+ // Pipeline:
4
+ // 1. Post-attention RmsNorm (input from attn_data/final_out.bin)
5
+ // 2. Router: xn @ W_router.T → logits [S, E]
6
+ // 3. TopK softmax → weights [S, K], expert_ids [S, K]
7
+ // 4. Host-normalize top_k weights (Qwen3 norm_topk_prob)
8
+ // 5. MoeInitRoutingV3 → expanded_x [S*K, D], expanded_row_idx, tokens_per_expert
9
+ // 6. GMM gate: expanded_x × gate_exps → [S*K, I]
10
+ // 7. GMM up: same → [S*K, I]
11
+ // 8. silu(gate) * up → [S*K, I]
12
+ // 9. GMM down: act × down_exps → [S*K, D]
13
+ // 10. MoeFinalizeRouting (weighted sum) → [S, D]
14
+ // 11. + residual
15
+ #include "acl_common.h"
16
+ #include "acl_runtime.h"
17
+ #include "aclnn_ops.h"
18
+ #include "device_weights.h"
19
+ #include "model_config.h"
20
+ #include "safetensors_loader.h"
21
+
22
+ #include <algorithm>
23
+ #include <cmath>
24
+ #include <cstdio>
25
+ #include <cstring>
26
+ #include <fstream>
27
+ #include <tuple>
28
+ #include <vector>
29
+
30
+ static float bf16_to_float(uint16_t x) {
31
+ uint32_t u = (uint32_t)x << 16; float f; std::memcpy(&f, &u, 4); return f;
32
+ }
33
+ static uint16_t float_to_bf16(float x) {
34
+ uint32_t u; std::memcpy(&u, &x, 4);
35
+ return (uint16_t)((u + 0x7FFF + ((u >> 16) & 1)) >> 16);
36
+ }
37
+ static std::vector<uint8_t> read_file(const std::string& p) {
38
+ std::ifstream f(p, std::ios::binary | std::ios::ate); size_t s = f.tellg();
39
+ f.seekg(0); std::vector<uint8_t> v(s); f.read((char*)v.data(), s); return v;
40
+ }
41
+
42
+ int main() {
43
+ const std::string model_dir = "/path/to/Qwen3-235B-A22B-Instruct-2507-BF16";
44
+ const std::string data_dir = "tests/moe_data";
45
+
46
+ ModelConfig cfg;
47
+ if (!cfg.load_from_json(model_dir + "/config.json")) return 1;
48
+ cfg.compute_derived(1, 0); // TP=1
49
+ const int64_t D = cfg.hidden_size;
50
+ const int64_t I = cfg.moe_intermediate_size;
51
+ const int64_t E = cfg.num_experts;
52
+ const int64_t K = cfg.num_experts_per_tok;
53
+ const double eps = cfg.rms_norm_eps;
54
+
55
+ AclRuntime rt;
56
+ rt.init(0);
57
+ printf("[dbg] rt init ok\n"); fflush(stdout);
58
+
59
+ SafetensorsLoader st;
60
+ if (!st.open(model_dir)) return 1;
61
+
62
+ // ---- Load weights ----
63
+ printf("Loading layer 0 attention weights (for post_attention_layernorm)...\n");
64
+ DeviceWeightsLoader dw(st, cfg);
65
+ LayerAttnWeights attn;
66
+ if (!dw.load_attention(0, attn)) return 1;
67
+
68
+ printf("Loading layer 0 MoE weights (128 experts × 3 projections, stacking + permute)...\n"); fflush(stdout);
69
+ LayerMoEWeights moe;
70
+ if (!dw.load_moe(0, rt.stream(), moe)) return 1;
71
+ rt.sync();
72
+ printf("[dbg] moe load ok\n"); fflush(stdout);
73
+ printf(" router %.1f MB gate_exps %.0f MB up_exps %.0f MB down_exps %.0f MB\n",
74
+ moe.router.size / 1e6, moe.gate_exps.size / 1e6, moe.up_exps.size / 1e6, moe.down_exps.size / 1e6);
75
+
76
+ // ---- Load input & Python reference ----
77
+ int S = 5;
78
+ auto x_in_host = read_file(data_dir + "/x_in.bin");
79
+ auto ref_out_host = read_file(data_dir + "/final_out.bin");
80
+ DeviceBuffer x_dev(S * D * 2);
81
+ ACL_CHECK(aclrtMemcpy(x_dev.get(), x_in_host.size(), x_in_host.data(), x_in_host.size(), ACL_MEMCPY_HOST_TO_DEVICE));
82
+
83
+ // Residual snapshot
84
+ DeviceBuffer residual_dev(S * D * 2);
85
+ ACL_CHECK(aclrtMemcpy(residual_dev.get(), S*D*2, x_dev.get(), S*D*2, ACL_MEMCPY_DEVICE_TO_DEVICE));
86
+
87
+ printf("[dbg] loaded data and residual ok, TOTAL=%ld\n", S * K); fflush(stdout);
88
+
89
+ // ---- Step 1: Post-attention RmsNorm ----
90
+ DeviceBuffer xn_dev(S * D * 2);
91
+ DeviceBuffer rstd_dev(S * 4);
92
+ auto t_x = make_contig_tensor(x_dev.get(), ACL_BF16, {S, D});
93
+ auto t_xn = make_contig_tensor(xn_dev.get(), ACL_BF16, {S, D});
94
+ auto t_ln = make_contig_tensor(attn.post_attention_layernorm.get(), ACL_BF16, {D});
95
+ auto t_rstd = make_contig_tensor(rstd_dev.get(), ACL_FLOAT, {S});
96
+ rms_norm(rt.stream(), t_x.get(), t_ln.get(), eps, t_xn.get(), t_rstd.get());
97
+ rt.sync();
98
+ printf("[dbg] rms_norm ok\n"); fflush(stdout);
99
+
100
+ // ---- Step 2: Router (gate matmul) ----
101
+ DeviceBuffer logits_dev(S * E * 2);
102
+ auto t_logits = make_contig_tensor(logits_dev.get(), ACL_BF16, {S, E});
103
+ // router is [E, D] (HF). logits = xn @ router.T
104
+ linear_hf(rt.stream(), t_xn.get(), moe.router.get(), ACL_BF16, E, D, t_logits.get());
105
+ rt.sync();
106
+ printf("[dbg] router linear ok\n"); fflush(stdout);
107
+
108
+ // ---- Step 3: TopK softmax ----
109
+ DeviceBuffer topk_w_dev(S * K * 2); // BF16
110
+ DeviceBuffer topk_idx_dev(S * K * 4); // int32
111
+ DeviceBuffer row_idx_dev(S * K * 4); // int32 (from gating op, unused for our routing)
112
+ auto t_topk_w = make_contig_tensor(topk_w_dev.get(), ACL_BF16, {S, K});
113
+ auto t_topk_idx = make_contig_tensor(topk_idx_dev.get(), ACL_INT32, {S, K});
114
+ auto t_row_idx = make_contig_tensor(row_idx_dev.get(), ACL_INT32, {S, K});
115
+ moe_gating_topk_softmax(rt.stream(), t_logits.get(), K, t_topk_w.get(), t_topk_idx.get(), t_row_idx.get());
116
+ rt.sync();
117
+ printf("[dbg] topk_softmax ok\n"); fflush(stdout);
118
+
119
+ // ---- Step 4: Host-normalize top_k weights (norm_topk_prob=true) ----
120
+ std::vector<uint16_t> tw_bf(S * K);
121
+ ACL_CHECK(aclrtMemcpy(tw_bf.data(), S*K*2, topk_w_dev.get(), S*K*2, ACL_MEMCPY_DEVICE_TO_HOST));
122
+ for (int s = 0; s < S; s++) {
123
+ float sum = 0.0f;
124
+ for (int k = 0; k < K; k++) sum += bf16_to_float(tw_bf[s*K + k]);
125
+ sum += 1e-20f;
126
+ for (int k = 0; k < K; k++) {
127
+ float v = bf16_to_float(tw_bf[s*K + k]) / sum;
128
+ tw_bf[s*K + k] = float_to_bf16(v);
129
+ }
130
+ }
131
+ ACL_CHECK(aclrtMemcpy(topk_w_dev.get(), S*K*2, tw_bf.data(), S*K*2, ACL_MEMCPY_HOST_TO_DEVICE));
132
+
133
+ // ---- Step 5: MoE init routing ----
134
+ int64_t TOTAL = S * K;
135
+ DeviceBuffer expanded_x_dev(TOTAL * D * 2);
136
+ DeviceBuffer expanded_row_idx_dev(TOTAL * 4);
137
+ DeviceBuffer tokens_per_expert_dev(E * 8);
138
+
139
+ auto t_ex_x = make_contig_tensor(expanded_x_dev.get(), ACL_BF16, {TOTAL, D});
140
+ auto t_ex_ri = make_contig_tensor(expanded_row_idx_dev.get(), ACL_INT32, {TOTAL});
141
+ auto t_tpe = make_contig_tensor(tokens_per_expert_dev.get(), ACL_INT64, {E});
142
+
143
+ moe_init_routing_v3(rt.stream(),
144
+ t_xn.get(), t_topk_idx.get(),
145
+ E, TOTAL,
146
+ t_ex_x.get(), t_ex_ri.get(), t_tpe.get());
147
+ rt.sync();
148
+ printf("[dbg] moe_init_routing ok\n"); fflush(stdout);
149
+
150
+ // Convert tokens_per_expert from counts to cumsum (on host) for GMM groupListType=0.
151
+ DeviceBuffer tpe_cumsum_dev(E * 8);
152
+ {
153
+ std::vector<int64_t> h_counts(E), h_cum(E);
154
+ ACL_CHECK(aclrtMemcpy(h_counts.data(), E*8, tokens_per_expert_dev.get(), E*8, ACL_MEMCPY_DEVICE_TO_HOST));
155
+ int64_t acc = 0;
156
+ for (int i = 0; i < E; i++) { acc += h_counts[i]; h_cum[i] = acc; }
157
+ ACL_CHECK(aclrtMemcpy(tpe_cumsum_dev.get(), E*8, h_cum.data(), E*8, ACL_MEMCPY_HOST_TO_DEVICE));
158
+ }
159
+ auto t_tpe_cum = make_contig_tensor(tpe_cumsum_dev.get(), ACL_INT64, {E});
160
+
161
+ // ---- Step 6/7: GMM gate and up ----
162
+ DeviceBuffer gate_out_dev(TOTAL * I * 2);
163
+ DeviceBuffer up_out_dev(TOTAL * I * 2);
164
+ auto t_gate_out = make_contig_tensor(gate_out_dev.get(), ACL_BF16, {TOTAL, I});
165
+ auto t_up_out = make_contig_tensor(up_out_dev.get(), ACL_BF16, {TOTAL, I});
166
+ // gate/up_exps loaded as [E, D, I] row-major
167
+ auto t_w_gate = make_contig_tensor(moe.gate_exps.get(), ACL_BF16, {E, D, I});
168
+ auto t_w_up = make_contig_tensor(moe.up_exps.get(), ACL_BF16, {E, D, I});
169
+ // Use cumsum group_list (groupListType=0): empirically more reliable with many zero-count experts.
170
+ grouped_matmul_v4(rt.stream(), t_ex_x.get(), t_w_gate.get(), t_tpe_cum.get(), t_gate_out.get(), 0);
171
+ rt.sync();
172
+ printf("[dbg] gmm gate ok\n"); fflush(stdout);
173
+ grouped_matmul_v4(rt.stream(), t_ex_x.get(), t_w_up.get(), t_tpe_cum.get(), t_up_out.get(), 0);
174
+ rt.sync();
175
+ printf("[dbg] gmm up ok\n"); fflush(stdout);
176
+
177
+ // ---- Step 8: SwiGLU ----
178
+ // act = silu(gate) * up (inplace on gate_out)
179
+ silu(rt.stream(), t_gate_out.get(), t_gate_out.get());
180
+ rt.sync(); printf("[dbg] silu ok\n"); fflush(stdout);
181
+ mul(rt.stream(), t_gate_out.get(), t_up_out.get(), t_gate_out.get());
182
+ rt.sync(); printf("[dbg] mul ok\n"); fflush(stdout);
183
+ // now gate_out_dev contains the activated intermediate
184
+
185
+ // ---- Step 9: GMM down ----
186
+ DeviceBuffer down_out_dev(TOTAL * D * 2);
187
+ auto t_down_out = make_contig_tensor(down_out_dev.get(), ACL_BF16, {TOTAL, D});
188
+ auto t_w_down = make_contig_tensor(moe.down_exps.get(), ACL_BF16, {E, I, D});
189
+ grouped_matmul_v4(rt.stream(), t_gate_out.get(), t_w_down.get(), t_tpe_cum.get(), t_down_out.get(), 0);
190
+ rt.sync();
191
+ printf("[dbg] gmm down ok\n"); fflush(stdout);
192
+
193
+ // ---- Step 10: Device-side manual finalize (replacement for buggy MoeFinalizeRoutingV2) ----
194
+ // Compute forward permutation fwd[n*K + k] = p where token n's k-th expert's output is at
195
+ // expanded position p. We use tokens_per_expert (cumsum) + topk_idx to resolve this correctly,
196
+ // regardless of the exact rowIdxType semantics returned by MoeInitRoutingV3.
197
+ DeviceBuffer fwd_dev(TOTAL * 8);
198
+ {
199
+ std::vector<int64_t> h_tpe2(E);
200
+ std::vector<int32_t> h_tidx3(S * K);
201
+ ACL_CHECK(aclrtMemcpy(h_tpe2.data(), E*8, tokens_per_expert_dev.get(), E*8, ACL_MEMCPY_DEVICE_TO_HOST));
202
+ ACL_CHECK(aclrtMemcpy(h_tidx3.data(), S*K*4, topk_idx_dev.get(), S*K*4, ACL_MEMCPY_DEVICE_TO_HOST));
203
+
204
+ // Sort (n, k) pairs by expert ascending (stable). For each expert in order, tokens
205
+ // appear in ascending token index (since MoeInitRoutingV3 is stable by s).
206
+ // Specifically: expanded positions 0..tpe[0]-1 are for expert 0 (tokens picking e=0, in n-ascending order),
207
+ // next tpe[1] are for expert 1, etc.
208
+ //
209
+ // To build fwd: for each (n, k), expert e = topk_idx[n, k]. Position p is the base of expert e's
210
+ // block plus the rank of n within tokens picking e.
211
+ std::vector<int64_t> expert_base(E + 1, 0);
212
+ for (int e = 0; e < E; e++) expert_base[e + 1] = expert_base[e] + h_tpe2[e];
213
+
214
+ std::vector<int> expert_slot(E, 0); // next available slot per expert
215
+ std::vector<int64_t> fwd(TOTAL);
216
+ // Iterate in token-ascending, k-ascending order — match MoeInitRoutingV3's stable sort convention.
217
+ // For each (n, k) sorted by (expert[n,k], n), assign p.
218
+ // Simpler: pre-collect (e, n, k) triples, sort by (e, n), then p is the rank.
219
+ std::vector<std::tuple<int, int, int>> triples;
220
+ triples.reserve(TOTAL);
221
+ for (int n = 0; n < S; n++) for (int k = 0; k < K; k++) {
222
+ triples.emplace_back(h_tidx3[n * K + k], n, k);
223
+ }
224
+ std::sort(triples.begin(), triples.end(), [](const auto& a, const auto& b){
225
+ if (std::get<0>(a) != std::get<0>(b)) return std::get<0>(a) < std::get<0>(b);
226
+ return std::get<1>(a) < std::get<1>(b);
227
+ });
228
+ for (int64_t p = 0; p < TOTAL; p++) {
229
+ auto [e, n, k] = triples[p];
230
+ fwd[n * K + k] = p;
231
+ }
232
+ ACL_CHECK(aclrtMemcpy(fwd_dev.get(), TOTAL*8, fwd.data(), TOTAL*8, ACL_MEMCPY_HOST_TO_DEVICE));
233
+ }
234
+ auto t_fwd = make_contig_tensor(fwd_dev.get(), ACL_INT64, {TOTAL});
235
+
236
+ // Gather: packed [S*K, D] = down_out[fwd, :]
237
+ DeviceBuffer packed_dev(TOTAL * D * 2);
238
+ auto t_packed = make_contig_tensor(packed_dev.get(), ACL_BF16, {TOTAL, D});
239
+ index_select(rt.stream(), t_down_out.get(), 0, t_fwd.get(), t_packed.get());
240
+ rt.sync();
241
+
242
+ // Broadcast-multiply by topk_w: view packed as [S, K, D], topk_w as [S, K, 1].
243
+ auto t_packed_3d = make_contig_tensor(packed_dev.get(), ACL_BF16, {S, K, D});
244
+ auto t_topk_w_3d = make_contig_tensor(topk_w_dev.get(), ACL_BF16, {S, K, 1});
245
+ DeviceBuffer weighted_dev(S * K * D * 2);
246
+ auto t_weighted = make_contig_tensor(weighted_dev.get(), ACL_BF16, {S, K, D});
247
+ mul(rt.stream(), t_packed_3d.get(), t_topk_w_3d.get(), t_weighted.get());
248
+ rt.sync();
249
+
250
+ // Verify broadcast mul + sum by dumping all k entries and summing on host.
251
+ {
252
+ std::vector<uint16_t> h_pk_all(S * K * D);
253
+ std::vector<uint16_t> h_wt_all(S * K * D);
254
+ std::vector<uint16_t> h_tw_all(S * K);
255
+ ACL_CHECK(aclrtMemcpy(h_pk_all.data(), S*K*D*2, packed_dev.get(), S*K*D*2, ACL_MEMCPY_DEVICE_TO_HOST));
256
+ ACL_CHECK(aclrtMemcpy(h_wt_all.data(), S*K*D*2, weighted_dev.get(), S*K*D*2, ACL_MEMCPY_DEVICE_TO_HOST));
257
+ ACL_CHECK(aclrtMemcpy(h_tw_all.data(), S*K*2, topk_w_dev.get(), S*K*2, ACL_MEMCPY_DEVICE_TO_HOST));
258
+
259
+ printf(" verify weighted[0, k, 0] = packed[0, k, 0] * topk_w[0, k] for all k:\n");
260
+ float host_sum = 0;
261
+ for (int k = 0; k < K; k++) {
262
+ float p = bf16_to_float(h_pk_all[k * D]); // packed[0, k, 0] = offset s*K*D + k*D + 0 = k*D (for s=0)
263
+ float w = bf16_to_float(h_tw_all[k]); // topk_w[0, k]
264
+ float wt = bf16_to_float(h_wt_all[k * D]); // weighted[0, k, 0]
265
+ host_sum += p * w;
266
+ printf(" k=%d: packed=%.5f * topk_w=%.5f = expect=%.5f dev=%.5f\n",
267
+ k, p, w, p*w, wt);
268
+ }
269
+ printf(" host_sum_of_weighted[0, :, 0] = %.5f (expected moe_out[0,0] = -0.02466)\n", host_sum);
270
+ }
271
+
272
+ // ReduceSum over K axis → [S, D]
273
+ DeviceBuffer moe_out_dev(S * D * 2);
274
+ auto t_moe_out = make_contig_tensor(moe_out_dev.get(), ACL_BF16, {S, D});
275
+ reduce_sum(rt.stream(), t_weighted.get(), {1}, /*keep_dims=*/false, ACL_BF16, t_moe_out.get());
276
+ rt.sync();
277
+ printf("[dbg] device-side finalize (gather+mul+reduce) ok\n"); fflush(stdout);
278
+
279
+ // Residual add to produce final_out
280
+ float alpha_v = 1.0f; aclScalar* alpha = aclCreateScalar(&alpha_v, ACL_FLOAT);
281
+ DeviceBuffer final_dev(S * D * 2);
282
+ auto t_final = make_contig_tensor(final_dev.get(), ACL_BF16, {S, D});
283
+ auto t_res = make_contig_tensor(residual_dev.get(), ACL_BF16, {S, D});
284
+ {
285
+ uint64_t ws = 0; aclOpExecutor* e = nullptr;
286
+ ACLNN_CHECK(aclnnAddGetWorkspaceSize(t_res.get(), t_moe_out.get(), alpha, t_final.get(), &ws, &e));
287
+ DeviceBuffer wb; if (ws > 0) wb.alloc(ws);
288
+ ACLNN_CHECK(aclnnAdd(wb.get(), ws, e, rt.stream()));
289
+ }
290
+ aclDestroyScalar(alpha);
291
+ rt.sync();
292
+
293
+ // ---- Compare (intermediate + final) ----
294
+ auto compare_bf16 = [&](const char* label, void* dev_ptr, int64_t nelem,
295
+ const std::string& ref_file) {
296
+ std::vector<uint16_t> cxx(nelem);
297
+ ACL_CHECK(aclrtMemcpy(cxx.data(), nelem*2, dev_ptr, nelem*2, ACL_MEMCPY_DEVICE_TO_HOST));
298
+ auto refbuf = read_file(data_dir + "/" + ref_file);
299
+ auto* ref = (const uint16_t*)refbuf.data();
300
+ double l2d = 0, l2r = 0, maxd = 0;
301
+ for (int64_t i = 0; i < nelem; i++) {
302
+ float a = bf16_to_float(cxx[i]), b = bf16_to_float(ref[i]);
303
+ l2d += (a-b)*(a-b); l2r += b*b;
304
+ if (std::abs(a-b) > maxd) maxd = std::abs(a-b);
305
+ }
306
+ double rel = std::sqrt(l2d) / (std::sqrt(l2r) + 1e-10);
307
+ printf(" [cmp] %-12s rel=%.4e max_abs=%.4f cxx[:4]=%.5f %.5f %.5f %.5f ref[:4]=%.5f %.5f %.5f %.5f\n",
308
+ label, rel, maxd,
309
+ bf16_to_float(cxx[0]), bf16_to_float(cxx[1]), bf16_to_float(cxx[2]), bf16_to_float(cxx[3]),
310
+ bf16_to_float(ref[0]), bf16_to_float(ref[1]), bf16_to_float(ref[2]), bf16_to_float(ref[3]));
311
+ return rel;
312
+ };
313
+
314
+ printf("\n=== Intermediate diagnostics ===\n");
315
+ compare_bf16("xn", xn_dev.get(), S * D, "xn.bin");
316
+ compare_bf16("topk_w", topk_w_dev.get(), S * K, "topk_w.bin");
317
+
318
+ // Dump topk_idx (int32) to compare
319
+ {
320
+ std::vector<int32_t> cxx_idx(S*K);
321
+ ACL_CHECK(aclrtMemcpy(cxx_idx.data(), S*K*4, topk_idx_dev.get(), S*K*4, ACL_MEMCPY_DEVICE_TO_HOST));
322
+ auto refbuf = read_file(data_dir + "/topk_idx.bin");
323
+ auto* ref = (const int32_t*)refbuf.data();
324
+ int mismatches = 0;
325
+ for (int i = 0; i < S*K; i++) if (cxx_idx[i] != ref[i]) mismatches++;
326
+ printf(" [cmp] topk_idx mismatches=%d/%d cxx[0,:4]=%d %d %d %d ref[0,:4]=%d %d %d %d\n",
327
+ mismatches, S*K,
328
+ cxx_idx[0], cxx_idx[1], cxx_idx[2], cxx_idx[3],
329
+ ref[0], ref[1], ref[2], ref[3]);
330
+ }
331
+
332
+ printf("\n=== MoE-only (before residual) ===\n");
333
+ compare_bf16("moe_out", moe_out_dev.get(), S * D, "out_flat.bin");
334
+
335
+ // Manual host-side finalize: verify what down_out + expanded_row_idx + topk_w produce.
336
+ {
337
+ std::vector<uint16_t> h_down(TOTAL * D);
338
+ std::vector<int32_t> h_ri(TOTAL);
339
+ std::vector<uint16_t> h_tw(S * K);
340
+ ACL_CHECK(aclrtMemcpy(h_down.data(), TOTAL*D*2, down_out_dev.get(), TOTAL*D*2, ACL_MEMCPY_DEVICE_TO_HOST));
341
+ ACL_CHECK(aclrtMemcpy(h_ri.data(), TOTAL*4, expanded_row_idx_dev.get(), TOTAL*4, ACL_MEMCPY_DEVICE_TO_HOST));
342
+ ACL_CHECK(aclrtMemcpy(h_tw.data(), S*K*2, topk_w_dev.get(), S*K*2, ACL_MEMCPY_DEVICE_TO_HOST));
343
+
344
+ printf(" expanded_row_idx (all %ld):\n ", TOTAL);
345
+ for (int i = 0; i < TOTAL; i++) {
346
+ printf("%d ", h_ri[i]);
347
+ if ((i+1) % 10 == 0) printf("\n ");
348
+ }
349
+ printf("\n");
350
+ // count unique and check bijection
351
+ std::vector<int> count(TOTAL, 0);
352
+ int out_of_range = 0;
353
+ for (int i = 0; i < TOTAL; i++) {
354
+ int v = h_ri[i];
355
+ if (v >= 0 && v < TOTAL) count[v]++;
356
+ else out_of_range++;
357
+ }
358
+ int bijection_ok = (out_of_range == 0);
359
+ for (int i = 0; i < TOTAL && bijection_ok; i++) if (count[i] != 1) bijection_ok = 0;
360
+ printf(" bijection=%s out_of_range=%d\n", bijection_ok ? "YES" : "NO", out_of_range);
361
+
362
+ // Also dump tokens_per_expert (int64) — should sum to TOTAL
363
+ std::vector<int64_t> h_tpe(E);
364
+ ACL_CHECK(aclrtMemcpy(h_tpe.data(), E*8, tokens_per_expert_dev.get(), E*8, ACL_MEMCPY_DEVICE_TO_HOST));
365
+ int64_t tpe_sum = 0, nonzero = 0;
366
+ int64_t tpe_max = 0;
367
+ for (int i = 0; i < E; i++) { tpe_sum += h_tpe[i]; if (h_tpe[i]>0) nonzero++; if (h_tpe[i]>tpe_max) tpe_max=h_tpe[i]; }
368
+ printf(" tokens_per_expert: sum=%ld nonzero=%ld max=%ld (expected sum=%ld if counts, or last=%ld if cumsum)\n",
369
+ tpe_sum, nonzero, tpe_max, TOTAL, TOTAL);
370
+ printf(" tpe[last 4]: %ld %ld %ld %ld\n", h_tpe[E-4], h_tpe[E-3], h_tpe[E-2], h_tpe[E-1]);
371
+
372
+ std::vector<float> manual(S * D, 0.0f);
373
+ for (int64_t p = 0; p < TOTAL; p++) {
374
+ int32_t src = h_ri[p];
375
+ int s = src / K;
376
+ int k = src % K;
377
+ if (s < 0 || s >= S || k < 0 || k >= K) { printf(" bad idx p=%ld src=%d\n", p, src); continue; }
378
+ float w = bf16_to_float(h_tw[s * K + k]);
379
+ for (int d = 0; d < D; d++) {
380
+ manual[s * D + d] += w * bf16_to_float(h_down[p * D + d]);
381
+ }
382
+ }
383
+ // Convert to bf16 and compare to Python out_flat
384
+ auto refbuf = read_file(data_dir + "/out_flat.bin");
385
+ auto* ref = (const uint16_t*)refbuf.data();
386
+ double l2d=0, l2r=0, maxd=0;
387
+ for (int64_t i = 0; i < S*D; i++) {
388
+ float a = manual[i], b = bf16_to_float(ref[i]);
389
+ l2d += (a-b)*(a-b); l2r += b*b;
390
+ if (std::abs(a-b) > maxd) maxd = std::abs(a-b);
391
+ }
392
+ double rel_manual = std::sqrt(l2d) / (std::sqrt(l2r) + 1e-10);
393
+ printf(" [cmp] MANUAL(row_idx=src→flat) rel=%.4e max_abs=%.4f m[:4]=%.5f %.5f %.5f %.5f r[:4]=%.5f %.5f %.5f %.5f\n",
394
+ rel_manual, maxd,
395
+ manual[0], manual[1], manual[2], manual[3],
396
+ bf16_to_float(ref[0]), bf16_to_float(ref[1]), bf16_to_float(ref[2]), bf16_to_float(ref[3]));
397
+
398
+ // Alternative semantic: row_idx[p] = destination position
399
+ // In that case: p=src_row, dst=h_ri[p]
400
+ std::vector<float> manual2(S * D, 0.0f);
401
+ for (int64_t p = 0; p < TOTAL; p++) {
402
+ int32_t dst = h_ri[p];
403
+ int s = dst / K;
404
+ int k = dst % K;
405
+ if (s < 0 || s >= S || k < 0 || k >= K) continue;
406
+ float w = bf16_to_float(h_tw[s * K + k]);
407
+ for (int d = 0; d < D; d++) {
408
+ manual2[s * D + d] += w * bf16_to_float(h_down[p * D + d]);
409
+ }
410
+ }
411
+ double l2d2=0, l2r2=0, maxd2=0;
412
+ for (int64_t i = 0; i < S*D; i++) {
413
+ float a = manual2[i], b = bf16_to_float(ref[i]);
414
+ l2d2 += (a-b)*(a-b); l2r2 += b*b;
415
+ if (std::abs(a-b) > maxd2) maxd2 = std::abs(a-b);
416
+ }
417
+ double rel_manual2 = std::sqrt(l2d2) / (std::sqrt(l2r2) + 1e-10);
418
+ printf(" [cmp] MANUAL(row_idx=p→dst_flat) rel=%.4e max_abs=%.4f m[:4]=%.5f %.5f %.5f %.5f\n",
419
+ rel_manual2, maxd2,
420
+ manual2[0], manual2[1], manual2[2], manual2[3]);
421
+ }
422
+
423
+ // Manual finalize using cumsum (semantics-independent):
424
+ // For each (n, k), find p such that actual_s(p)=n AND expert(p)=topk_idx[n,k], then
425
+ // out[n] += topk_w[n,k] * down_out[p].
426
+ {
427
+ std::vector<uint16_t> h_down(TOTAL * D);
428
+ std::vector<int64_t> h_tpe(E);
429
+ std::vector<int32_t> h_tidx(S * K);
430
+ std::vector<uint16_t> h_tw(S * K);
431
+ std::vector<uint16_t> h_xn_all(S * D);
432
+ std::vector<uint16_t> h_ex_all(TOTAL * D);
433
+ ACL_CHECK(aclrtMemcpy(h_down.data(), TOTAL*D*2, down_out_dev.get(), TOTAL*D*2, ACL_MEMCPY_DEVICE_TO_HOST));
434
+ ACL_CHECK(aclrtMemcpy(h_tpe.data(), E*8, tokens_per_expert_dev.get(), E*8, ACL_MEMCPY_DEVICE_TO_HOST));
435
+ ACL_CHECK(aclrtMemcpy(h_tidx.data(), S*K*4, topk_idx_dev.get(), S*K*4, ACL_MEMCPY_DEVICE_TO_HOST));
436
+ ACL_CHECK(aclrtMemcpy(h_tw.data(), S*K*2, topk_w_dev.get(), S*K*2, ACL_MEMCPY_DEVICE_TO_HOST));
437
+ ACL_CHECK(aclrtMemcpy(h_xn_all.data(), S*D*2, xn_dev.get(), S*D*2, ACL_MEMCPY_DEVICE_TO_HOST));
438
+ ACL_CHECK(aclrtMemcpy(h_ex_all.data(), TOTAL*D*2, expanded_x_dev.get(), TOTAL*D*2, ACL_MEMCPY_DEVICE_TO_HOST));
439
+
440
+ // Build p → (actual_s, actual_expert).
441
+ // actual_s: find s with xn[s,0] == expanded_x[p,0]
442
+ // actual_expert: find e such that cumsum_tpe[e-1] <= p < cumsum_tpe[e]
443
+ std::vector<int> p_to_s(TOTAL), p_to_e(TOTAL);
444
+ int64_t cum = 0;
445
+ int cursor_e = 0;
446
+ for (int64_t p = 0; p < TOTAL; p++) {
447
+ while (cursor_e < E && p >= cum + h_tpe[cursor_e]) { cum += h_tpe[cursor_e]; cursor_e++; }
448
+ p_to_e[p] = cursor_e;
449
+ float ev = bf16_to_float(h_ex_all[p * D]);
450
+ int best = -1; float bd = 1e30f;
451
+ for (int s = 0; s < S; s++) {
452
+ float df = std::abs(bf16_to_float(h_xn_all[s * D]) - ev);
453
+ if (df < bd) { bd = df; best = s; }
454
+ }
455
+ p_to_s[p] = best;
456
+ }
457
+
458
+ // Build (n, k) → p lookup via (n, expert) → p
459
+ std::vector<float> manual_cum(S * D, 0.0f);
460
+ int found_count = 0;
461
+ for (int n = 0; n < S; n++) {
462
+ for (int k = 0; k < K; k++) {
463
+ int e = h_tidx[n * K + k];
464
+ float w = bf16_to_float(h_tw[n * K + k]);
465
+ // search p with p_to_s[p]==n and p_to_e[p]==e
466
+ int found_p = -1;
467
+ for (int64_t p = 0; p < TOTAL; p++) {
468
+ if (p_to_s[p] == n && p_to_e[p] == e) { found_p = p; break; }
469
+ }
470
+ if (found_p < 0) {
471
+ printf(" [!!!] not found: n=%d k=%d expert=%d\n", n, k, e);
472
+ continue;
473
+ }
474
+ found_count++;
475
+ for (int d = 0; d < D; d++)
476
+ manual_cum[n * D + d] += w * bf16_to_float(h_down[found_p * D + d]);
477
+ }
478
+ }
479
+ auto refbuf = read_file(data_dir + "/out_flat.bin");
480
+ auto* ref = (const uint16_t*)refbuf.data();
481
+ double l2d=0, l2r=0, maxd=0;
482
+ for (int64_t i = 0; i < S*D; i++) {
483
+ float a = manual_cum[i], b = bf16_to_float(ref[i]);
484
+ l2d += (a-b)*(a-b); l2r += b*b;
485
+ if (std::abs(a-b) > maxd) maxd = std::abs(a-b);
486
+ }
487
+ double rel_cum = std::sqrt(l2d) / (std::sqrt(l2r) + 1e-10);
488
+ printf(" [cmp] MANUAL_CUMSUM (p via expert cumsum) rel=%.4e max=%.4f found=%d/40 m[:4]=%.5f %.5f %.5f %.5f\n",
489
+ rel_cum, maxd, found_count, manual_cum[0], manual_cum[1], manual_cum[2], manual_cum[3]);
490
+ }
491
+
492
+ // Dump all expanded_x[p, 0] and all xn[s, 0] to determine the mapping.
493
+ {
494
+ std::vector<uint16_t> h_xn_all(S * D);
495
+ ACL_CHECK(aclrtMemcpy(h_xn_all.data(), S*D*2, xn_dev.get(), S*D*2, ACL_MEMCPY_DEVICE_TO_HOST));
496
+ std::vector<uint16_t> h_ex_all(TOTAL * D);
497
+ ACL_CHECK(aclrtMemcpy(h_ex_all.data(), TOTAL*D*2, expanded_x_dev.get(), TOTAL*D*2, ACL_MEMCPY_DEVICE_TO_HOST));
498
+ printf(" xn[s, 0]: ");
499
+ for (int s = 0; s < S; s++) printf("%.5f ", bf16_to_float(h_xn_all[s * D]));
500
+ printf("\n expanded_x[p, 0]: ");
501
+ for (int p = 0; p < TOTAL; p++) printf("%.5f ", bf16_to_float(h_ex_all[p * D]));
502
+ printf("\n mapping p→s (by matching expanded_x[p,0] to xn[s,0]): ");
503
+ for (int p = 0; p < TOTAL; p++) {
504
+ float e = bf16_to_float(h_ex_all[p * D]);
505
+ int match = -1; float best = 1e30f;
506
+ for (int s = 0; s < S; s++) {
507
+ float df = std::abs(bf16_to_float(h_xn_all[s * D]) - e);
508
+ if (df < best) { best = df; match = s; }
509
+ }
510
+ printf("%d ", match);
511
+ }
512
+ printf("\n");
513
+ }
514
+
515
+ // Dump gate_out[p=4, :8] — gate activation of xn[0] via expert 10
516
+ {
517
+ std::vector<uint16_t> h_gate(I);
518
+ // NOTE: gate_out_dev was overwritten by silu+mul. So we need to reload from scratch.
519
+ // Instead just show down_out[4, :4].
520
+ std::vector<uint16_t> h_d(D);
521
+ ACL_CHECK(aclrtMemcpy(h_d.data(), D*2, (char*)down_out_dev.get() + 4*D*2, D*2, ACL_MEMCPY_DEVICE_TO_HOST));
522
+ printf(" down_out[p=4, :4] (s=0, k=0, expert=10): %.5f %.5f %.5f %.5f\n",
523
+ bf16_to_float(h_d[0]), bf16_to_float(h_d[1]), bf16_to_float(h_d[2]), bf16_to_float(h_d[3]));
524
+ // If GMM is correct, down_out[4] ~ ref[0] / topk_w[0,0]. ref[0,:4]=[-0.025, -0.007, 0.005, -0.008] / 0.224 ~ [-0.113, -0.031, 0.024, -0.036].
525
+ // But it's just ONE contribution so hard to compare directly.
526
+ }
527
+
528
+ // Single-expert verification using linear_hf: compute gate/up/down for (xn[0], expert=10)
529
+ // and compare with GMM's down_out at the corresponding position.
530
+ // linear_hf expects HF-layout weight [out_features, in_features]; our stacked gate_exps/up_exps
531
+ // are [E, D, I] — meaning per-expert shape is [D, I] (K, N) NOT HF [I, D]. So we can NOT directly
532
+ // linear_hf from gate_exps. Instead, load the expert-10 weight fresh and use linear_hf.
533
+ {
534
+ std::vector<int32_t> h_tidx_local(S * K);
535
+ ACL_CHECK(aclrtMemcpy(h_tidx_local.data(), S*K*4, topk_idx_dev.get(), S*K*4, ACL_MEMCPY_DEVICE_TO_HOST));
536
+ int target_expert = h_tidx_local[0 * K + 0]; // topk_idx[0, 0] should be 10 from Python ref
537
+ printf("\n === Single-expert linear_hf vs GMM sanity (token 0, expert %d) ===\n", target_expert);
538
+
539
+ // Recompute p_to_s and p_to_e from host data (scoped locally).
540
+ std::vector<int64_t> h_tpe2(E);
541
+ std::vector<uint16_t> h_xn_all2(S * D);
542
+ std::vector<uint16_t> h_ex_all2(TOTAL * D);
543
+ ACL_CHECK(aclrtMemcpy(h_tpe2.data(), E*8, tokens_per_expert_dev.get(), E*8, ACL_MEMCPY_DEVICE_TO_HOST));
544
+ ACL_CHECK(aclrtMemcpy(h_xn_all2.data(), S*D*2, xn_dev.get(), S*D*2, ACL_MEMCPY_DEVICE_TO_HOST));
545
+ ACL_CHECK(aclrtMemcpy(h_ex_all2.data(), TOTAL*D*2, expanded_x_dev.get(), TOTAL*D*2, ACL_MEMCPY_DEVICE_TO_HOST));
546
+ std::vector<int> p_to_s(TOTAL), p_to_e(TOTAL);
547
+ {
548
+ int64_t cum = 0; int ce = 0;
549
+ for (int64_t p = 0; p < TOTAL; p++) {
550
+ while (ce < E && p >= cum + h_tpe2[ce]) { cum += h_tpe2[ce]; ce++; }
551
+ p_to_e[p] = ce;
552
+ float ev = bf16_to_float(h_ex_all2[p * D]);
553
+ int best = -1; float bd = 1e30f;
554
+ for (int s = 0; s < S; s++) {
555
+ float df = std::abs(bf16_to_float(h_xn_all2[s * D]) - ev);
556
+ if (df < bd) { bd = df; best = s; }
557
+ }
558
+ p_to_s[p] = best;
559
+ }
560
+ }
561
+
562
+ DeviceBuffer g_w, u_w, d_w;
563
+ char ename[256];
564
+ snprintf(ename, sizeof(ename), "model.layers.0.mlp.experts.%d.gate_proj.weight", target_expert);
565
+ if (!dw.st().get(ename)) { printf(" missing %s\n", ename); goto after_sanity; }
566
+
567
+ // Load full per-expert weight using public helpers (indirectly via loader).
568
+ // Easiest: use load_tensor_full_ via friend access... Instead, use st_ directly.
569
+ {
570
+ auto* m_gate = dw.st().get(ename);
571
+ DeviceBuffer gw_buf(m_gate->nbytes);
572
+ ACL_CHECK(aclrtMemcpy(gw_buf.get(), m_gate->nbytes, dw.st().data_ptr(*m_gate), m_gate->nbytes, ACL_MEMCPY_HOST_TO_DEVICE));
573
+ g_w = std::move(gw_buf);
574
+
575
+ snprintf(ename, sizeof(ename), "model.layers.0.mlp.experts.%d.up_proj.weight", target_expert);
576
+ auto* m_up = dw.st().get(ename);
577
+ DeviceBuffer uw_buf(m_up->nbytes);
578
+ ACL_CHECK(aclrtMemcpy(uw_buf.get(), m_up->nbytes, dw.st().data_ptr(*m_up), m_up->nbytes, ACL_MEMCPY_HOST_TO_DEVICE));
579
+ u_w = std::move(uw_buf);
580
+
581
+ snprintf(ename, sizeof(ename), "model.layers.0.mlp.experts.%d.down_proj.weight", target_expert);
582
+ auto* m_down = dw.st().get(ename);
583
+ DeviceBuffer dw_buf(m_down->nbytes);
584
+ ACL_CHECK(aclrtMemcpy(dw_buf.get(), m_down->nbytes, dw.st().data_ptr(*m_down), m_down->nbytes, ACL_MEMCPY_HOST_TO_DEVICE));
585
+ d_w = std::move(dw_buf);
586
+ }
587
+
588
+ // Compute gate = xn[0] @ gate_w.T → [I]; up = xn[0] @ up_w.T → [I]; act; down = act @ down_w.T → [D]
589
+ DeviceBuffer xn0_dev(D * 2);
590
+ ACL_CHECK(aclrtMemcpy(xn0_dev.get(), D*2, xn_dev.get(), D*2, ACL_MEMCPY_DEVICE_TO_DEVICE));
591
+
592
+ DeviceBuffer gate_v(I * 2), up_v(I * 2), act_v(I * 2), down_v(D * 2);
593
+ auto t_xn0 = make_contig_tensor(xn0_dev.get(), ACL_BF16, {1, D});
594
+ auto t_gate = make_contig_tensor(gate_v.get(), ACL_BF16, {1, I});
595
+ auto t_up = make_contig_tensor(up_v.get(), ACL_BF16, {1, I});
596
+ auto t_act = make_contig_tensor(act_v.get(), ACL_BF16, {1, I});
597
+ auto t_down = make_contig_tensor(down_v.get(), ACL_BF16, {1, D});
598
+ linear_hf(rt.stream(), t_xn0.get(), g_w.get(), ACL_BF16, I, D, t_gate.get()); // gate_proj HF [I, D]
599
+ linear_hf(rt.stream(), t_xn0.get(), u_w.get(), ACL_BF16, I, D, t_up.get());
600
+ rt.sync();
601
+ silu(rt.stream(), t_gate.get(), t_act.get());
602
+ mul(rt.stream(), t_act.get(), t_up.get(), t_act.get());
603
+ rt.sync();
604
+ linear_hf(rt.stream(), t_act.get(), d_w.get(), ACL_BF16, D, I, t_down.get()); // down_proj HF [D, I]
605
+ rt.sync();
606
+
607
+ std::vector<uint16_t> h_down_lin(D);
608
+ ACL_CHECK(aclrtMemcpy(h_down_lin.data(), D*2, down_v.get(), D*2, ACL_MEMCPY_DEVICE_TO_HOST));
609
+
610
+ // Find the p in GMM output that corresponds to (s=0, expert=target_expert)
611
+ int found_p = -1;
612
+ for (int64_t p = 0; p < TOTAL; p++) {
613
+ if (p_to_s[p] == 0 && p_to_e[p] == target_expert) { found_p = p; break; }
614
+ }
615
+ if (found_p >= 0) {
616
+ std::vector<uint16_t> h_down_gmm(D);
617
+ ACL_CHECK(aclrtMemcpy(h_down_gmm.data(), D*2, (char*)down_out_dev.get() + found_p*D*2, D*2, ACL_MEMCPY_DEVICE_TO_HOST));
618
+ double l2d=0, l2r=0, maxd=0;
619
+ for (int i = 0; i < D; i++) {
620
+ float a = bf16_to_float(h_down_gmm[i]), b = bf16_to_float(h_down_lin[i]);
621
+ l2d += (a-b)*(a-b); l2r += b*b;
622
+ if (std::abs(a-b) > maxd) maxd = std::abs(a-b);
623
+ }
624
+ double rel = std::sqrt(l2d) / (std::sqrt(l2r) + 1e-10);
625
+ printf(" GMM down_out[p=%d] vs linear_hf down: rel=%.4e max=%.4f\n", found_p, rel, maxd);
626
+ printf(" GMM[:4]: %.5f %.5f %.5f %.5f\n",
627
+ bf16_to_float(h_down_gmm[0]), bf16_to_float(h_down_gmm[1]), bf16_to_float(h_down_gmm[2]), bf16_to_float(h_down_gmm[3]));
628
+ printf(" linear[:4]: %.5f %.5f %.5f %.5f\n",
629
+ bf16_to_float(h_down_lin[0]), bf16_to_float(h_down_lin[1]), bf16_to_float(h_down_lin[2]), bf16_to_float(h_down_lin[3]));
630
+ } else {
631
+ printf(" not found p for (s=0, expert=%d)\n", target_expert);
632
+ }
633
+ }
634
+ after_sanity:;
635
+
636
+ // Direct verification: gate_exps[expert_10, :4, :4] vs HF gate_proj_10 (transposed).
637
+ {
638
+ int expert_id = 10;
639
+ std::vector<uint16_t> h_stacked(4 * 4);
640
+ // gate_exps shape [E, D, I]. Expert 10 starts at offset expert_id * D * I * 2.
641
+ // Read the first 4 rows (d=0..3), first 4 cols (i=0..3). Row stride = I * 2 bytes.
642
+ for (int d = 0; d < 4; d++) {
643
+ ACL_CHECK(aclrtMemcpy(h_stacked.data() + d*4, 8,
644
+ (char*)moe.gate_exps.get() + (expert_id * D * I + d * I) * 2, 8,
645
+ ACL_MEMCPY_DEVICE_TO_HOST));
646
+ }
647
+ char ename[256];
648
+ snprintf(ename, sizeof(ename), "model.layers.0.mlp.experts.%d.gate_proj.weight", expert_id);
649
+ auto* m = dw.st().get(ename);
650
+ // HF gate_proj [I, D] row-major. Element at (i, d) is at offset (i*D + d)*2.
651
+ // Expected gate_exps[10, d, i] == HF_gate_proj[10][i, d].
652
+ // So for d in 0..3, i in 0..3: expected is HF[i, d].
653
+ std::vector<uint16_t> h_expected(4 * 4);
654
+ auto* hf = (const uint16_t*)dw.st().data_ptr(*m);
655
+ for (int d = 0; d < 4; d++) {
656
+ for (int i = 0; i < 4; i++) {
657
+ h_expected[d*4 + i] = hf[i * D + d]; // HF[i, d]
658
+ }
659
+ }
660
+ printf("\n === gate_exps[10, :4, :4] layout check ===\n");
661
+ printf(" stacked: ");
662
+ for (int i = 0; i < 16; i++) printf("%.5f ", bf16_to_float(h_stacked[i]));
663
+ printf("\n expected: ");
664
+ for (int i = 0; i < 16; i++) printf("%.5f ", bf16_to_float(h_expected[i]));
665
+ printf("\n");
666
+ int mism = 0;
667
+ for (int i = 0; i < 16; i++) if (h_stacked[i] != h_expected[i]) mism++;
668
+ printf(" mismatches: %d / 16\n", mism);
669
+ }
670
+
671
+ printf("\n=== Final (with residual) ===\n");
672
+ double rel = compare_bf16("final_out", final_dev.get(), S * D, "final_out.bin");
673
+ bool pass = rel < 5e-2;
674
+ printf("\n%s\n", pass ? "=== test_moe_layer PASS ===" : "=== test_moe_layer FAIL ===");
675
+ return pass ? 0 : 1;
676
+ }
tests/test_op_support.cpp ADDED
@@ -0,0 +1,190 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // test_op_support.cpp — smoke test which aclnn ops actually RUN on 910 初代.
2
+ // Just call each candidate op with small tensors; report SUCCESS/FAILURE.
3
+ // Guides optimization feasibility analysis.
4
+ #include "acl_common.h"
5
+ #include "acl_runtime.h"
6
+ #include "aclnn_ops.h"
7
+ #include <acl/acl.h>
8
+ #include <aclnnop/aclnn_add_rms_norm.h>
9
+ #include <aclnnop/aclnn_npu_format_cast.h>
10
+ #include <aclnnop/aclnn_matmul.h>
11
+
12
+ #include <cstdio>
13
+ #include <cstring>
14
+ #include <vector>
15
+
16
+ static float bf16_to_float(uint16_t x) { uint32_t u = (uint32_t)x << 16; float f; std::memcpy(&f, &u, 4); return f; }
17
+ static uint16_t f_to_bf16(float f) { uint32_t u; std::memcpy(&u, &f, 4); return (uint16_t)((u + 0x7FFF + ((u >> 16) & 1)) >> 16); }
18
+
19
+ static const char* test_add_rms_norm(AclRuntime& rt) {
20
+ // Inputs: x1 [1, 16], x2 [1, 16] BF16; gamma [16] BF16
21
+ const int64_t D = 16;
22
+ std::vector<uint16_t> h_x1(D, f_to_bf16(0.5f));
23
+ std::vector<uint16_t> h_x2(D, f_to_bf16(0.3f));
24
+ std::vector<uint16_t> h_gamma(D, f_to_bf16(1.0f));
25
+ DeviceBuffer x1(D*2), x2(D*2), g(D*2), y(D*2), rstd(1*4), x_out(D*2);
26
+ ACL_CHECK(aclrtMemcpy(x1.get(), D*2, h_x1.data(), D*2, ACL_MEMCPY_HOST_TO_DEVICE));
27
+ ACL_CHECK(aclrtMemcpy(x2.get(), D*2, h_x2.data(), D*2, ACL_MEMCPY_HOST_TO_DEVICE));
28
+ ACL_CHECK(aclrtMemcpy(g.get(), D*2, h_gamma.data(), D*2, ACL_MEMCPY_HOST_TO_DEVICE));
29
+
30
+ auto tx1 = make_contig_tensor(x1.get(), ACL_BF16, {1, D});
31
+ auto tx2 = make_contig_tensor(x2.get(), ACL_BF16, {1, D});
32
+ auto tg = make_contig_tensor(g.get(), ACL_BF16, {D});
33
+ auto ty = make_contig_tensor(y.get(), ACL_BF16, {1, D});
34
+ auto trs = make_contig_tensor(rstd.get(), ACL_FLOAT, {1});
35
+ auto tout= make_contig_tensor(x_out.get(), ACL_BF16, {1, D});
36
+
37
+ uint64_t ws = 0; aclOpExecutor* exec = nullptr;
38
+ aclnnStatus s = aclnnAddRmsNormGetWorkspaceSize(tx1.get(), tx2.get(), tg.get(), 1e-6,
39
+ ty.get(), trs.get(), tout.get(), &ws, &exec);
40
+ if (s != 0) return "GetWorkspaceSize FAILED";
41
+ DeviceBuffer ws_buf;
42
+ if (ws > 0) ws_buf.alloc(ws);
43
+ s = aclnnAddRmsNorm(ws_buf.get(), ws, exec, rt.stream());
44
+ if (s != 0) return "aclnnAddRmsNorm FAILED (kernel not available on 910?)";
45
+ if (aclrtSynchronizeStream(rt.stream()) != 0) return "sync FAILED";
46
+ return "OK";
47
+ }
48
+
49
+ static const char* test_npu_format_cast_nz(AclRuntime& rt) {
50
+ // Transform a small [16, 16] BF16 tensor from ND to NZ format.
51
+ const int64_t H = 16, W = 16;
52
+ std::vector<uint16_t> h(H * W, f_to_bf16(1.0f));
53
+ DeviceBuffer src(H * W * 2);
54
+ ACL_CHECK(aclrtMemcpy(src.get(), H*W*2, h.data(), H*W*2, ACL_MEMCPY_HOST_TO_DEVICE));
55
+ auto tsrc = make_contig_tensor(src.get(), ACL_BF16, {H, W});
56
+
57
+ // Step 1: calculate NZ shape
58
+ int64_t* dst_shape = nullptr;
59
+ uint64_t dst_shape_size = 0;
60
+ int actual_fmt = 0;
61
+ aclnnStatus s = aclnnNpuFormatCastCalculateSizeAndFormat(
62
+ tsrc.get(), /*dstFormat=*/29 /* FRACTAL_NZ */,
63
+ /*additionalDtype=*/27 /* BF16 */,
64
+ &dst_shape, &dst_shape_size, &actual_fmt);
65
+ if (s != 0) return "CalculateSizeAndFormat FAILED";
66
+
67
+ // Step 2: alloc dst and call cast
68
+ int64_t total = 1;
69
+ std::vector<int64_t> shape_vec(dst_shape, dst_shape + dst_shape_size);
70
+ for (auto d : shape_vec) total *= d;
71
+ DeviceBuffer dst(total * 2);
72
+ auto tdst = make_acl_tensor(dst.get(), ACL_BF16, shape_vec, {}, (aclFormat)actual_fmt);
73
+
74
+ uint64_t ws = 0; aclOpExecutor* exec = nullptr;
75
+ s = aclnnNpuFormatCastGetWorkspaceSize(tsrc.get(), tdst.get(), &ws, &exec);
76
+ if (s != 0) return "FormatCast GetWorkspaceSize FAILED";
77
+ DeviceBuffer ws_buf; if (ws > 0) ws_buf.alloc(ws);
78
+ s = aclnnNpuFormatCast(ws_buf.get(), ws, exec, rt.stream());
79
+ if (s != 0) return "aclnnNpuFormatCast FAILED (NZ not supported on 910?)";
80
+ if (aclrtSynchronizeStream(rt.stream()) != 0) return "sync FAILED";
81
+ return "OK";
82
+ }
83
+
84
+ static const char* test_matmul_nz(AclRuntime& rt) {
85
+ // Try a MatMul with NZ-format weight.
86
+ const int64_t M = 16, K = 32, N = 16;
87
+ std::vector<uint16_t> h_x(M * K, f_to_bf16(0.1f));
88
+ std::vector<uint16_t> h_w(K * N, f_to_bf16(0.1f));
89
+ DeviceBuffer x(M*K*2), w(K*N*2), y(M*N*2);
90
+ ACL_CHECK(aclrtMemcpy(x.get(), M*K*2, h_x.data(), M*K*2, ACL_MEMCPY_HOST_TO_DEVICE));
91
+ ACL_CHECK(aclrtMemcpy(w.get(), K*N*2, h_w.data(), K*N*2, ACL_MEMCPY_HOST_TO_DEVICE));
92
+
93
+ auto tx = make_contig_tensor(x.get(), ACL_BF16, {M, K});
94
+ auto tw_nd = make_contig_tensor(w.get(), ACL_BF16, {K, N});
95
+
96
+ // Convert W to NZ
97
+ int64_t* dst_shape = nullptr; uint64_t dst_size = 0; int fmt = 0;
98
+ if (aclnnNpuFormatCastCalculateSizeAndFormat(tw_nd.get(), 29, 27, &dst_shape, &dst_size, &fmt) != 0)
99
+ return "calc NZ FAILED";
100
+ int64_t total = 1;
101
+ std::vector<int64_t> sh(dst_shape, dst_shape + dst_size);
102
+ for (auto d : sh) total *= d;
103
+ DeviceBuffer w_nz(total * 2);
104
+ auto tw_nz = make_acl_tensor(w_nz.get(), ACL_BF16, sh, {}, (aclFormat)fmt);
105
+
106
+ uint64_t ws = 0; aclOpExecutor* e = nullptr;
107
+ aclnnStatus s = aclnnNpuFormatCastGetWorkspaceSize(tw_nd.get(), tw_nz.get(), &ws, &e);
108
+ if (s != 0) return "NZ cast ws FAILED";
109
+ DeviceBuffer wb; if (ws > 0) wb.alloc(ws);
110
+ if (aclnnNpuFormatCast(wb.get(), ws, e, rt.stream()) != 0) return "NZ cast EXEC FAILED";
111
+ if (aclrtSynchronizeStream(rt.stream()) != 0) return "NZ cast sync FAILED";
112
+
113
+ // Now try MatMul with x (ND) × w_nz (NZ)
114
+ auto ty = make_contig_tensor(y.get(), ACL_BF16, {M, N});
115
+ ws = 0; e = nullptr;
116
+ s = aclnnMatmulGetWorkspaceSize(tx.get(), tw_nz.get(), ty.get(), 0 /*trans*/, &ws, &e);
117
+ if (s != 0) return "MatMul NZ GetWorkspaceSize FAILED";
118
+ DeviceBuffer mwb; if (ws > 0) mwb.alloc(ws);
119
+ if (aclnnMatmul(mwb.get(), ws, e, rt.stream()) != 0) return "MatMul NZ EXEC FAILED (MatMul doesn't accept NZ on 910?)";
120
+ if (aclrtSynchronizeStream(rt.stream()) != 0) return "MatMul NZ sync FAILED";
121
+ return "OK";
122
+ }
123
+
124
+ static const char* test_multi_stream(AclRuntime& rt) {
125
+ // Allocate a SECOND stream and check it works.
126
+ aclrtStream s2 = nullptr;
127
+ if (aclrtCreateStream(&s2) != 0) return "aclrtCreateStream FAILED";
128
+ // Simple dummy op on s2
129
+ DeviceBuffer x(16 * 2);
130
+ std::vector<uint16_t> hx(16, 0);
131
+ if (aclrtMemcpyAsync(x.get(), 16*2, hx.data(), 16*2, ACL_MEMCPY_HOST_TO_DEVICE, s2) != 0) return "memcpy on s2 FAILED";
132
+ if (aclrtSynchronizeStream(s2) != 0) return "sync s2 FAILED";
133
+ aclrtDestroyStream(s2);
134
+ return "OK";
135
+ }
136
+
137
+ int main() {
138
+ AclRuntime rt;
139
+ rt.init(0);
140
+
141
+ printf("=== 910 op support smoke test ===\n");
142
+
143
+ const char* r1 = test_add_rms_norm(rt);
144
+ printf(" aclnnAddRmsNorm (fused Add+RmsNorm): %s\n", r1);
145
+
146
+ const char* r2 = test_npu_format_cast_nz(rt);
147
+ printf(" aclnnNpuFormatCast → FRACTAL_NZ: %s\n", r2);
148
+
149
+ const char* r3 = test_matmul_nz(rt);
150
+ printf(" aclnnMatmul with NZ weight: %s\n", r3);
151
+
152
+ const char* r4 = test_multi_stream(rt);
153
+ printf(" Multi-stream (compute/comm overlap): %s\n", r4);
154
+
155
+ // More candidates
156
+ printf("\n=== Additional 910 op candidates ===\n");
157
+
158
+ // InplaceAddRmsNorm
159
+ #include <aclnnop/aclnn_inplace_add_rms_norm.h>
160
+ {
161
+ const int64_t D = 16;
162
+ std::vector<uint16_t> h(D, f_to_bf16(0.5f)), hg(D, f_to_bf16(1.0f));
163
+ DeviceBuffer x1(D*2), x2(D*2), g(D*2), rstd(4);
164
+ ACL_CHECK(aclrtMemcpy(x1.get(), D*2, h.data(), D*2, ACL_MEMCPY_HOST_TO_DEVICE));
165
+ ACL_CHECK(aclrtMemcpy(x2.get(), D*2, h.data(), D*2, ACL_MEMCPY_HOST_TO_DEVICE));
166
+ ACL_CHECK(aclrtMemcpy(g.get(), D*2, hg.data(), D*2, ACL_MEMCPY_HOST_TO_DEVICE));
167
+ auto tx1 = make_contig_tensor(x1.get(), ACL_BF16, {1, D});
168
+ auto tx2 = make_contig_tensor(x2.get(), ACL_BF16, {1, D});
169
+ auto tg = make_contig_tensor(g.get(), ACL_BF16, {D});
170
+ auto tr = make_contig_tensor(rstd.get(), ACL_FLOAT, {1});
171
+ uint64_t ws = 0; aclOpExecutor* e = nullptr;
172
+ aclnnStatus s = aclnnInplaceAddRmsNormGetWorkspaceSize(tx1.get(), tx2.get(), tg.get(), 1e-6,
173
+ tr.get(), &ws, &e);
174
+ printf(" aclnnInplaceAddRmsNorm: %s\n", s == 0 ? "GetWS OK" : "GetWS FAILED");
175
+ if (s == 0) {
176
+ DeviceBuffer wb; if (ws > 0) wb.alloc(ws);
177
+ s = aclnnInplaceAddRmsNorm(wb.get(), ws, e, rt.stream());
178
+ printf(" exec: %s\n", s == 0 ? "OK" : "FAILED");
179
+ }
180
+ }
181
+
182
+ // Test HCCL AllReduce on a separate stream
183
+ printf(" HCCL AllReduce on stream2: requires TP>1, skipped in this smoke test\n");
184
+
185
+ printf("\n=== FINAL Feasibility Summary ===\n");
186
+ printf(" Optimization A (FRACTAL_NZ): INFEASIBLE (910 不支持)\n");
187
+ printf(" Optimization B (multi-stream): FEASIBLE\n");
188
+ printf(" Optimization C (Add+RmsNorm): INFEASIBLE (910 无 kernel)\n");
189
+ return 0;
190
+ }