Entrit committed on
Commit 51e3123 · verified · 1 Parent(s): 79e1a6b

initial public release: code, README, KNOWN_ISSUES

Files changed (6)
  1. KNOWN_ISSUES.md +64 -0
  2. README.md +116 -0
  3. build.sh +52 -0
  4. trit_gemv.cu +292 -0
  5. trit_gemv_lib.py +280 -0
  6. trit_gemv_standalone.cu +598 -0
KNOWN_ISSUES.md ADDED
@@ -0,0 +1,64 @@
1
+ # Known issues — tritllm-kernel
2
+
3
+ These issues surfaced during a pre-release code review. None affect the published paper benchmark numbers (those were obtained on shapes that respect the contract), but anyone using these kernels with new shapes, custom launch parameters, or as a drop-in inference primitive should be aware of them.
4
+
5
+ ## BLOCKER — must respect or fix before relying on the kernel
6
+
7
+ ### 1. Implicit one-warp-per-block launch contract
8
+ **Where:** [`trit_gemv.cu:190-237` (`trit_gemv_uniform`)](trit_gemv.cu#L190), [`trit_gemv.cu:245-290` (`trit_gemv_variable`)](trit_gemv.cu#L245)
9
+
10
+ The kernels use `lane = threadIdx.x` directly as the lane index and reduce with a full-warp mask `__shfl_down_sync(0xFFFFFFFF, ...)`. This is correct only when `blockDim.x == 32`.
11
+
12
+ If launched with `blockDim.x > 32`:
13
+ - Threads with `threadIdx.x >= 32` will compute `idx = lane*2+i` past the 64-element group bound and read out-of-bounds.
14
+ - All threads with lane 0 across multiple warps race to write `y[row]`.
15
+
16
+ **Fix in caller:** always launch with `blockDim.x == 32`. The host-side wrappers in `trit_gemv_standalone.cu` do this correctly. Direct callers from custom code must respect it.
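+
+ For reference, a minimal host-side launch sketch that respects the contract (pointer names are illustrative device buffers):
+
+ ```cpp
+ // One block per output row, exactly one warp (32 threads) per block.
+ dim3 grid(out_features);
+ dim3 block(WARP_SIZE);   // WARP_SIZE == 32; this is the whole contract
+ trit_gemv_uniform<<<grid, block>>>(packed_trits, scales, x, y,
+                                    in_features, out_features, depth);
+ ```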
17
+
18
+ **Future fix in kernel:** add `assert(blockDim.x == WARP_SIZE)` at kernel entry, or rewrite to handle multi-warp blocks correctly.
19
+
20
+ ### 2. `in_features` not a multiple of `GROUP_SIZE` is silently dropped
21
+ **Where:** [`trit_gemv.cu:194`](trit_gemv.cu#L194), [`trit_gemv.cu:259`](trit_gemv.cu#L259)
22
+
23
+ ```cpp
24
+ int num_groups = in_features / GROUP_SIZE;
25
+ ```
26
+
27
+ Integer division truncates. If `in_features % 64 != 0`, the trailing partial group is silently skipped and that fragment of the dot product is missing from the output.
28
+
29
+ **Fix in caller:** pad the input weight matrix (and activations) with zero rows to the next multiple of 64 before quantizing. The codec output already does this for Qwen, Llama, and Mistral architectures, all of which have `hidden_dim` divisible by 64.
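+
+ A minimal padding sketch (host side, before quantization; `GROUP_SIZE` is 64):
+
+ ```cpp
+ // Round in_features up to the next multiple of GROUP_SIZE so the
+ // truncating division above never drops a trailing partial group.
+ int padded_in = ((in_features + GROUP_SIZE - 1) / GROUP_SIZE) * GROUP_SIZE;
+ // Zero-pad W along the input dimension and x with (padded_in - in_features)
+ // zeros; zero weights contribute nothing, so results are unchanged.
+ ```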
30
+
31
+ **Future fix in kernel:** add `assert(in_features % GROUP_SIZE == 0)` at kernel entry, or write a tail-handling path.
32
+
33
+ ## SHOULD-FIX
34
+
35
+ ### 3. C API performs no input validation
36
+ **Where:** `trit_gemv_standalone.cu`, all `extern "C"` functions
37
+
38
+ `trit_gemv_d2_fast`, `trit_gemv_d2_dp4a`, `trit_gemv_d3_native`, etc. accept null pointers, mismatched `rows`/`cols`/`num_groups`, and incorrectly packed buffers without complaint. Bad inputs become device faults or OOB reads.
39
+
40
+ For a public ctypes-facing library this is a sharp edge. We will add a validation pass in a future revision; for now, callers must guarantee the validity of their arguments themselves.
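+
+ Until that lands, a caller-side guard along these lines is cheap insurance (a sketch for the d2 fast path; adapt to your wrapper):
+
+ ```cpp
+ // Check the documented preconditions before crossing the C boundary.
+ assert(pt && ws && xt_e && xt_o && xs && y);           // no null device pointers
+ assert(cols % GROUP_SIZE == 0);                        // see issue 2
+ assert(num_groups == cols / GROUP_SIZE && rows > 0);   // consistent shape triple
+ ```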
41
+
42
+ ### 4. `get_gpu_name(char* buf, int buflen)` has no null/length guard
43
+ **Where:** [`trit_gemv_standalone.cu:586`](trit_gemv_standalone.cu#L586)
44
+
45
+ Calling with `buf == nullptr` or `buflen <= 0` is immediate UB on the host side. Trivial fix; pending.
46
+
47
+ ### 5. CUDA error returns are not surfaced
48
+ **Where:** several places in `trit_gemv_standalone.cu` where `set_l2_persist`, kernel launches, and helper calls drop `cudaError_t` returns
49
+
50
+ If a kernel launch fails (e.g., bad shapes that pass the (missing) input validation), the failure is silent until the next `cudaDeviceSynchronize()` or `cudaGetLastError()`. The public functions return `void` and have no error-reporting path.
51
+
52
+ Workaround: call `cuda_sync()` after each operation and check `cudaGetLastError()` from your wrapper.
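+
+ A sketch of that workaround from a C or C++ caller (the same pattern applies from Python through the ctypes wrapper):
+
+ ```cpp
+ trit_gemv_d2_fast(pt, ws, xt_e, xt_o, xs, y, cols, rows, num_groups, 1);
+ cudaDeviceSynchronize();                          // what cuda_sync() does
+ cudaError_t err = cudaGetLastError();
+ if (err != cudaSuccess)
+     fprintf(stderr, "trit_gemv_d2_fast: %s\n", cudaGetErrorString(err));
+ ```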
53
+
54
+ ### 6. Reduction wastes 31 lanes per group
55
+ **Where:** [`trit_gemv.cu:223-232`](trit_gemv.cu#L223), [`trit_gemv.cu:279-286`](trit_gemv.cu#L279)
56
+
57
+ After the warp reduction, only lane 0 multiplies by the group scale and accumulates into `row_acc`; the other 31 lanes sit idle during the scale/add step. This is correct, but it leaves performance on the table relative to the deferred-reduction design used in `k_d3_hardened` (`trit_gemv_standalone.cu:493`).
58
+
59
+ The headline 7.8× number is from the deferred-reduction path, so this only matters if you use the educational `trit_gemv_uniform` / `trit_gemv_variable` kernels directly.
60
+
61
+ ## NIT
62
+
63
+ ### 7. Multiple prototype kernels in production file
64
+ `trit_gemv_standalone.cu` contains v9, v27, v28, v29, `k_d3_hardened`, plus the non-deferred kernels — a development history rather than a clean public surface. The `k_v29_pipeline` / `trit_pipeline` path was broken (it passed nullptr for required arrays) and was removed in a commit prior to this release. The remaining prototypes (`k_v27`, `k_v29`, `k_v28`) are still wired through public C functions; they work, but the API surface is wider than needed. A future revision will trim this to one canonical entry point per depth.
README.md ADDED
@@ -0,0 +1,116 @@
1
+ ---
2
+ license: apache-2.0
3
+ tags:
4
+ - cuda
5
+ - quantization
6
+ - ternary
7
+ - llm-inference
8
+ - kernel
9
+ ---
10
+
11
+ # tritllm-kernel
12
+
13
+ Multiply-free ternary GEMV CUDA kernel for the codec from
14
+ **"Balanced Ternary Post-Training Quantization for Large Language Models"** (Stentzel, 2026).
15
+
16
+ The headline number from the paper: **7.8× speedup** over cuBLAS FP16 GEMV on RTX 4090 in the memory-bound regime, projected to full-model token generation throughput from per-layer benchmarks.
17
+
18
+ > **These are kernel-only projections, not end-to-end serving throughput.** They exclude attention, KV cache, sampling, and tokenizer overhead. See Section 7 of the paper for methodology.
19
+
20
+ ## What it is
21
+
22
+ A standalone CUDA shared library (`libtrit_gemv.so` / `.dll`) callable via `ctypes` from any language, with no PyTorch dependency. The same algorithm is also wrapped via PyTorch's pybind11 in `trit_gemv_wrapper.cu` for benchmarking.
23
+
24
+ The core trick: each ternary weight (-1, 0, +1) reduces a multiply-accumulate to a conditional add/subtract/skip. The kernel uses the `dp4a` intrinsic on int4-packed weights and pre-interleaved int8 activations to perform four weight-times-activation products per instruction.
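+
+ As a scalar illustration of that trick (not the actual kernel, which operates on packed words via `dp4a`):
+
+ ```c
+ /* One output row, one group of 64 weights: multiply-free accumulation. */
+ float acc = 0.0f;
+ for (int k = 0; k < 64; k++) {
+     int8_t w = w_ternary[k];           /* -1, 0, or +1 */
+     if (w > 0)      acc += x[k];       /* add */
+     else if (w < 0) acc -= x[k];       /* subtract; w == 0 is skipped */
+ }
+ y_row_partial = acc * group_scale;     /* one FP multiply per 64 weights */
+ ```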
25
+
26
+ ## Build
27
+
28
+ ```bash
29
+ cd kernel
30
+ ./build.sh
31
+ ```
32
+
33
+ The build script targets SM 70/75/80/86/89/90 in one fat binary, and adds SM 120 when the installed CUDA toolkit supports it, so the `.so` runs on V100, T4, A100, RTX 30/40, H100, and (with a recent enough toolkit) RTX 50-series GPUs without recompilation.
34
+
35
+ Required: `nvcc` (CUDA 11.8 or newer) and a C++ compiler.
36
+
37
+ ## Performance (Qwen2.5-7B, d=2, 3.47 bpw, 3.3 GB model)
38
+
39
+ | GPU | L2 cache | Tokens/sec | Speedup vs FP16 cuBLAS | Effective BW |
40
+ |---|---|---|---|---|
41
+ | RTX 4090 | 72 MB | 588 | 7.8× | 1940 GB/s |
42
+ | RTX 3090 | 6 MB | 192 | 3.4× | 633 GB/s |
43
+ | RTX 4080 Laptop | 64 MB | 133 | 5.8× | 439 GB/s |
44
+ | A100 80GB | 40 MB | 201 | 4.2× | 663 GB/s |
45
+
46
+ These are per-layer GEMV benchmarks projected to full-model token-generation throughput. L2 cache size correlates strongly with speedup because each `d=2` layer fits entirely in L2 on the RTX 4090, giving an effective bandwidth of roughly 2× the card's DRAM bandwidth.
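+
+ The Python wrapper (`TritGEMV.gemv_adaptive` in `trit_gemv_lib.py`) uses the same observation to choose a path at call time; the heuristic boils down to:
+
+ ```c
+ /* Keep the compact int4 weights resident in L2 when the layer fits,
+    otherwise fall back to the pre-expanded int8 DRAM path. */
+ size_t weight_bytes = (size_t)rows * num_groups * 8 * sizeof(int32_t);
+ int fits_in_l2 = weight_bytes < 0.75 * get_l2_cache_bytes();  /* 25% headroom for x, scales */
+ ```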
47
+
48
+ See `kernel/bench_*.py` for the benchmark drivers.
49
+
50
+ ## Launch contract
51
+
52
+ The kernels in `trit_gemv.cu` and `trit_gemv_standalone.cu` assume:
53
+
54
+ | Constraint | Why | What happens if violated |
55
+ |---|---|---|
56
+ | `blockDim.x == 32` (one warp per block) | Kernels use `__shfl_down_sync(0xFFFFFFFF, ...)` and lane-0 reduction | OOB index reads + race on `y[row]` |
57
+ | `in_features % 64 == 0` | Group size is fixed at 64 weights | Trailing partial group is silently dropped — incorrect output for that row |
58
+ | Weight, scale, and activation buffers are device-resident and properly aligned | Kernel uses `__ldg` for cached loads | UB / device fault |
59
+
60
+ If your model has `in_features` not divisible by 64, pad the weight matrix to the next multiple of 64 with zero rows before quantizing.
61
+
62
+ ## API surface
63
+
64
+ C ABI in `trit_gemv_standalone.cu`:
65
+
66
+ ```c
67
+ // Best-tested d=2 path (champion for 4090)
68
+ void trit_gemv_d2_fast(
69
+ const int32_t* pt, // [rows * num_groups * 8] int4-packed weights
70
+ const float* ws, // [rows * num_groups] scales
71
+ const int32_t* xt_e, // [num_groups * 8] even nibble activations
72
+ const int32_t* xt_o, // [num_groups * 8] odd nibble activations
73
+ const float* xs, // [num_groups] activation scales
74
+ float* y, // [rows] output
75
+ int cols, int rows, int num_groups,
76
+ int use_l2_persist // 0 = off, 1 = enable L2 persistence
77
+ );
78
+
79
+ // Native-trit packed d=3 (no int4 intermediate)
80
+ void trit_gemv_d3_native(
81
+ const int32_t* pt, // [rows * num_groups * 13] trit-packed
82
+ const float* sc,
83
+ const float* x,
84
+ float* y,
85
+ int cols, int rows, int depth
86
+ );
87
+
88
+ // L2 cache size query (for deciding whether to enable persist)
89
+ int get_l2_cache_bytes();
90
+ void get_gpu_name(char* buf, int buflen);
91
+ void cuda_sync();
92
+ ```
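+
+ A minimal call sequence (a sketch; the device buffers are assumed to be allocated and packed per the shape comments above):
+
+ ```c
+ int num_groups = cols / 64;                  /* cols must be a multiple of 64 */
+ trit_gemv_d2_fast(pt, ws, xt_e, xt_o, xs, y,
+                   cols, rows, num_groups, /*use_l2_persist=*/1);
+ cuda_sync();                                 /* flush and surface any launch error */
+ ```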
93
+
94
+ ## Known issues
95
+
96
+ Documented in [KNOWN_ISSUES.md](KNOWN_ISSUES.md). Summary:
97
+
98
+ - **Launch contract is implicit, not enforced.** Kernels are correct only with `blockDim.x == 32`. There are no runtime asserts; the contract is guarded only by the host-side wrappers in `trit_gemv_standalone.cu`. Direct callers must respect it.
99
+ - **`in_features` not a multiple of 64 silently fails.** No assert. Pad your matrix.
100
+ - **C API has no input validation.** Null pointers, wrong dimensions, and buffer-shape mismatches become device faults or OOB reads. This is a public-API hardening item we have not yet completed.
101
+ - **CUDA error returns are not surfaced to the caller** in some helper paths. If a kernel launch fails, `cuda_sync()` will see it but the public functions return `void`.
102
+
103
+ ## Citation
104
+
105
+ ```
106
+ @article{stentzel2026ternaryptq,
107
+ title = {Balanced Ternary Post-Training Quantization for Large Language Models},
108
+ author = {Stentzel, Eric},
109
+ year = 2026,
110
+ note = {Entrit Systems}
111
+ }
112
+ ```
113
+
114
+ ## License
115
+
116
+ Apache-2.0.
build.sh ADDED
@@ -0,0 +1,52 @@
1
+ #!/bin/bash
2
+ # Build libtrit_gemv.so — standalone CUDA kernel library
3
+ # No PyTorch, no Python, no framework dependency.
4
+ # Just nvcc + CUDA runtime.
5
+ #
6
+ # Fat binary: compiles for all major GPU architectures.
7
+ # The right kernel is selected at runtime based on the GPU.
8
+
9
+ set -e
10
+ cd "$(dirname "$0")"
11
+
12
+ # Detect nvcc
13
+ NVCC=$(which nvcc 2>/dev/null || echo "/usr/local/cuda/bin/nvcc")
14
+ if [ ! -x "$NVCC" ]; then
15
+ echo "ERROR: nvcc not found. Install CUDA toolkit."
16
+ exit 1
17
+ fi
18
+
19
+ echo "Using nvcc: $NVCC"
20
+ $NVCC --version | head -1
21
+
22
+ # Architecture targets (fat binary)
23
+ # Volta (V100), Turing (2080), Ampere (3090/A100),
24
+ # Ada (4080/4090), Hopper (H100), Blackwell (5070+)
25
+ ARCHS=""
26
+ ARCHS="$ARCHS -gencode=arch=compute_70,code=sm_70" # V100
27
+ ARCHS="$ARCHS -gencode=arch=compute_75,code=sm_75" # 2080
28
+ ARCHS="$ARCHS -gencode=arch=compute_80,code=sm_80" # A100, 3080
29
+ ARCHS="$ARCHS -gencode=arch=compute_86,code=sm_86" # 3090
30
+ ARCHS="$ARCHS -gencode=arch=compute_89,code=sm_89" # 4080, 4090
31
+ ARCHS="$ARCHS -gencode=arch=compute_90,code=sm_90" # H100
32
+
33
+ # Blackwell — only if nvcc supports it (CUDA 12.8+)
34
+ if $NVCC --help 2>&1 | grep -q "compute_120"; then
35
+ ARCHS="$ARCHS -gencode=arch=compute_120,code=sm_120"
36
+ echo "Including Blackwell (sm_120)"
37
+ fi
38
+
39
+ echo "Building libtrit_gemv.so..."
40
+ $NVCC -O3 --use_fast_math \
41
+ -shared -Xcompiler -fPIC \
42
+ $ARCHS \
43
+ -o libtrit_gemv.so \
44
+ trit_gemv_standalone.cu
45
+
46
+ ls -la libtrit_gemv.so
47
+ echo "Done! Library ready at $(pwd)/libtrit_gemv.so"
48
+ echo ""
49
+ echo "Usage from Python:"
50
+ echo " from trit_gemv_lib import TritGEMV"
51
+ echo " lib = TritGEMV()"
52
+ echo " lib.gemv_d2(weights, scales, x_int8, x_scales, output, K, M, ng)"
trit_gemv.cu ADDED
@@ -0,0 +1,292 @@
1
+ /*
2
+ * TritLLM CUDA Kernel — Ternary GEMV (Matrix-Vector Multiply)
3
+ *
4
+ * Core operation: y = W_ternary @ x
5
+ * Where W_ternary is packed ternary weights with per-group scales.
6
+ *
7
+ * Each group of 64 weights has:
8
+ * - A depth (1-4 trits per weight)
9
+ * - A FP16 scale factor
10
+ * - Packed trit values (2 bits per trit: 00=0, 01=+1, 10=-1, 11=unused)
11
+ *
12
+ * The key: NO floating-point multiply in the inner loop.
13
+ * Ternary MAC = conditional add/subtract.
14
+ */
15
+
16
+ #include <cuda_fp16.h>
17
+ #include <cuda_runtime.h>
18
+ #include <stdint.h>
19
+
20
+ #define GROUP_SIZE 64
21
+ #define WARP_SIZE 32
22
+
23
+ // Trit encoding: 2 bits per trit
24
+ // 00 = 0, 01 = +1, 10 = -1
25
+ #define TRIT_ZERO 0
26
+ #define TRIT_POS 1
27
+ #define TRIT_NEG 2
28
+
29
+ /*
30
+ * Depth 1 (3 levels: {-1, 0, +1}): 1 trit per weight, 2 bits per weight
31
+ * Pack 16 trits per uint32 (16 * 2 = 32 bits)
32
+ * Group of 64 = 4 uint32s
33
+ *
34
+ * Inner loop: read trit, branch-free conditional accumulate
35
+ */
36
+ __device__ __forceinline__ float trit_mac_d1(
37
+ const uint32_t* __restrict__ packed, // 4 uint32s = 64 trits
38
+ const float* __restrict__ x, // 64 activations
39
+ int lane // warp lane (0-31)
40
+ ) {
41
+ float acc = 0.0f;
42
+
43
+ // Each thread in warp handles 2 elements (64 / 32 = 2)
44
+ #pragma unroll
45
+ for (int i = 0; i < 2; i++) {
46
+ int idx = lane * 2 + i;
47
+ int word = idx / 16; // which uint32 (0-3)
48
+ int bit_offset = (idx % 16) * 2; // bit position within word
49
+
50
+ uint32_t trit = (packed[word] >> bit_offset) & 0x3;
51
+ float val = x[idx];
52
+
53
+ // Branch-free: acc += (trit == 1) * val - (trit == 2) * val
54
+ acc += ((trit == TRIT_POS) - (trit == TRIT_NEG)) * val;
55
+ }
56
+
57
+ return acc;
58
+ }
59
+
60
+ /*
61
+ * Depth 2 (9 levels: {-4..+4}): 2 trits per weight, 4 bits per weight
62
+ * Level = sign(trit1) * 3 + sign(trit0), each sign in {-1, 0, +1} (maps to -4..+4)
63
+ * Pack 8 values per uint32 (8 * 4 = 32 bits)
64
+ * Group of 64 = 8 uint32s
65
+ */
66
+ __device__ __forceinline__ float trit_mac_d2(
67
+ const uint32_t* __restrict__ packed, // 8 uint32s = 64 values
68
+ const float* __restrict__ x,
69
+ int lane
70
+ ) {
71
+ float acc = 0.0f;
72
+
73
+ #pragma unroll
74
+ for (int i = 0; i < 2; i++) {
75
+ int idx = lane * 2 + i;
76
+ int word = idx / 8;
77
+ int bit_offset = (idx % 8) * 4;
78
+
79
+ uint32_t bits = (packed[word] >> bit_offset) & 0xF;
80
+ // Decode: trit1 = bits >> 2, trit0 = bits & 0x3
81
+ // value = (trit1_sign * 3 + trit0_sign)
82
+ // where trit_sign: 00->0, 01->+1, 10->-1
83
+ int t0 = (int)(bits & 0x3);
84
+ int t1 = (int)((bits >> 2) & 0x3);
85
+ int sign0 = (t0 == TRIT_POS) - (t0 == TRIT_NEG);
86
+ int sign1 = (t1 == TRIT_POS) - (t1 == TRIT_NEG);
87
+ int level = sign1 * 3 + sign0; // -4 to +4
88
+
89
+ // Still no FP multiply — integer * float is one instruction
90
+ // level is small integer, compiler optimizes to repeated add
91
+ acc += level * x[idx];
92
+ }
93
+
94
+ return acc;
95
+ }
96
+
97
+ /*
98
+ * Depth 3 (27 levels: {-13..+13}): 3 trits per weight, 6 bits per weight
99
+ * Pack 5 values per uint32 (5 * 6 = 30 bits, 2 wasted)
100
+ * Group of 64 = 13 uint32s (64 values, last uint32 has 4 values)
101
+ */
102
+ __device__ __forceinline__ float trit_mac_d3(
103
+ const uint32_t* __restrict__ packed, // 13 uint32s
104
+ const float* __restrict__ x,
105
+ int lane
106
+ ) {
107
+ float acc = 0.0f;
108
+
109
+ #pragma unroll
110
+ for (int i = 0; i < 2; i++) {
111
+ int idx = lane * 2 + i;
112
+ int word = idx / 5;
113
+ int pos = idx % 5;
114
+ int bit_offset = pos * 6;
115
+
116
+ uint32_t bits = (packed[word] >> bit_offset) & 0x3F;
117
+ int t0 = (int)(bits & 0x3);
118
+ int t1 = (int)((bits >> 2) & 0x3);
119
+ int t2 = (int)((bits >> 4) & 0x3);
120
+ int s0 = (t0 == TRIT_POS) - (t0 == TRIT_NEG);
121
+ int s1 = (t1 == TRIT_POS) - (t1 == TRIT_NEG);
122
+ int s2 = (t2 == TRIT_POS) - (t2 == TRIT_NEG);
123
+ int level = s2 * 9 + s1 * 3 + s0; // -13 to +13
124
+
125
+ acc += level * x[idx];
126
+ }
127
+
128
+ return acc;
129
+ }
130
+
131
+ /*
132
+ * Depth 4 (81 levels: {-40..+40}): 4 trits per weight, 8 bits per weight
133
+ * Pack 4 values per uint32 (4 * 8 = 32 bits, perfect)
134
+ * Group of 64 = 16 uint32s
135
+ */
136
+ __device__ __forceinline__ float trit_mac_d4(
137
+ const uint32_t* __restrict__ packed, // 16 uint32s
138
+ const float* __restrict__ x,
139
+ int lane
140
+ ) {
141
+ float acc = 0.0f;
142
+
143
+ #pragma unroll
144
+ for (int i = 0; i < 2; i++) {
145
+ int idx = lane * 2 + i;
146
+ int word = idx / 4;
147
+ int bit_offset = (idx % 4) * 8;
148
+
149
+ uint32_t bits = (packed[word] >> bit_offset) & 0xFF;
150
+ int t0 = (int)(bits & 0x3);
151
+ int t1 = (int)((bits >> 2) & 0x3);
152
+ int t2 = (int)((bits >> 4) & 0x3);
153
+ int t3 = (int)((bits >> 6) & 0x3);
154
+ int s0 = (t0 == TRIT_POS) - (t0 == TRIT_NEG);
155
+ int s1 = (t1 == TRIT_POS) - (t1 == TRIT_NEG);
156
+ int s2 = (t2 == TRIT_POS) - (t2 == TRIT_NEG);
157
+ int s3 = (t3 == TRIT_POS) - (t3 == TRIT_NEG);
158
+ int level = s3 * 27 + s2 * 9 + s1 * 3 + s0;
159
+
160
+ acc += level * x[idx];
161
+ }
162
+
163
+ return acc;
164
+ }
165
+
166
+ /*
167
+ * Main GEMV kernel: y[out_features] = W[out_features, in_features] @ x[in_features]
168
+ *
169
+ * W is stored as packed ternary groups:
170
+ * - packed_trits: variable-length packed trit data per group
171
+ * - scales: FP16 scale per group
172
+ * - depths: uint8 depth per group (1-4)
173
+ * - group_offsets: byte offset into packed_trits for each group
174
+ *
175
+ * One warp per output row, iterating over groups along the input dimension.
176
+ * Warp reduction gives the final dot product.
177
+ */
178
+
179
+ // Simplified version: uniform depth across all groups in a tensor
180
+ // (variable-depth version below)
181
+ __global__ void trit_gemv_uniform(
182
+ const uint32_t* __restrict__ packed_trits, // packed trit data
183
+ const float* __restrict__ scales, // [out_features * num_groups] FP16 stored as float
184
+ const float* __restrict__ x, // [in_features]
185
+ float* __restrict__ y, // [out_features]
186
+ int in_features,
187
+ int out_features,
188
+ int depth // uniform depth 1-4
189
+ ) {
190
+ int row = blockIdx.x; // one block per output row
191
+ if (row >= out_features) return;
192
+
193
+ int lane = threadIdx.x; // lane within warp (0-31)
194
+ int num_groups = in_features / GROUP_SIZE;
195
+
196
+ // Words per group depends on depth
197
+ int words_per_group;
198
+ switch (depth) {
199
+ case 1: words_per_group = 4; break; // 64 * 2 / 32
200
+ case 2: words_per_group = 8; break; // 64 * 4 / 32
201
+ case 3: words_per_group = 13; break; // ceil(64 * 6 / 32)
202
+ case 4: words_per_group = 16; break; // 64 * 8 / 32
203
+ default: words_per_group = 4; break;
204
+ }
205
+
206
+ float row_acc = 0.0f;
207
+
208
+ for (int g = 0; g < num_groups; g++) {
209
+ int group_offset = (row * num_groups + g) * words_per_group;
210
+ const uint32_t* group_data = &packed_trits[group_offset];
211
+ const float* group_x = &x[g * GROUP_SIZE];
212
+ float scale = scales[row * num_groups + g];
213
+
214
+ float group_acc;
215
+ switch (depth) {
216
+ case 1: group_acc = trit_mac_d1(group_data, group_x, lane); break;
217
+ case 2: group_acc = trit_mac_d2(group_data, group_x, lane); break;
218
+ case 3: group_acc = trit_mac_d3(group_data, group_x, lane); break;
219
+ case 4: group_acc = trit_mac_d4(group_data, group_x, lane); break;
220
+ default: group_acc = 0.0f; break;
221
+ }
222
+
223
+ // Warp reduction
224
+ #pragma unroll
225
+ for (int offset = 16; offset > 0; offset >>= 1) {
226
+ group_acc += __shfl_down_sync(0xFFFFFFFF, group_acc, offset);
227
+ }
228
+
229
+ // Lane 0 accumulates the scaled result
230
+ if (lane == 0) {
231
+ row_acc += group_acc * scale;
232
+ }
233
+ }
234
+
235
+ // Write output
236
+ if (lane == 0) {
237
+ y[row] = row_acc;
238
+ }
239
+ }
240
+
241
+ /*
242
+ * Variable-depth version: each group can have a different depth.
243
+ * Uses a depth map and offset table to handle mixed-depth tensors.
244
+ */
245
+ __global__ void trit_gemv_variable(
246
+ const uint32_t* __restrict__ packed_trits,
247
+ const float* __restrict__ scales,
248
+ const uint8_t* __restrict__ depth_map, // [num_groups_per_row] depth per group
249
+ const int* __restrict__ group_offsets, // [num_groups_per_row + 1] word offsets
250
+ const float* __restrict__ x,
251
+ float* __restrict__ y,
252
+ int in_features,
253
+ int out_features
254
+ ) {
255
+ int row = blockIdx.x;
256
+ if (row >= out_features) return;
257
+
258
+ int lane = threadIdx.x;
259
+ int num_groups = in_features / GROUP_SIZE;
260
+
261
+ float row_acc = 0.0f;
262
+
263
+ for (int g = 0; g < num_groups; g++) {
264
+ int depth = depth_map[g];
265
+ int word_offset = group_offsets[g] + row * group_offsets[num_groups]; // row stride
266
+ const uint32_t* group_data = &packed_trits[word_offset];
267
+ const float* group_x = &x[g * GROUP_SIZE];
268
+ float scale = scales[row * num_groups + g];
269
+
270
+ float group_acc;
271
+ switch (depth) {
272
+ case 1: group_acc = trit_mac_d1(group_data, group_x, lane); break;
273
+ case 2: group_acc = trit_mac_d2(group_data, group_x, lane); break;
274
+ case 3: group_acc = trit_mac_d3(group_data, group_x, lane); break;
275
+ case 4: group_acc = trit_mac_d4(group_data, group_x, lane); break;
276
+ default: group_acc = 0.0f; break;
277
+ }
278
+
279
+ #pragma unroll
280
+ for (int offset = 16; offset > 0; offset >>= 1) {
281
+ group_acc += __shfl_down_sync(0xFFFFFFFF, group_acc, offset);
282
+ }
283
+
284
+ if (lane == 0) {
285
+ row_acc += group_acc * scale;
286
+ }
287
+ }
288
+
289
+ if (lane == 0) {
290
+ y[row] = row_acc;
291
+ }
292
+ }
trit_gemv_lib.py ADDED
@@ -0,0 +1,280 @@
1
+ """Framework-agnostic trit GEMV library.
2
+
3
+ Loads the pre-compiled libtrit_gemv.so via ctypes.
4
+ Works with PyTorch, JAX, CuPy, or raw CUDA pointers.
5
+
6
+ Compile the library once:
7
+ cd kernel/
8
+ ./build.sh
9
+
10
+ Then use from any framework:
11
+ from trit_gemv_lib import TritGEMV
12
+ lib = TritGEMV()
13
+
14
+ # PyTorch
15
+ lib.gemv_d2(pt_tensor, ws_tensor, xt_tensor, xs_tensor, y_tensor, cols, rows, ng)
16
+
17
+ # Raw pointers (CuPy, JAX, etc.)
18
+ lib.gemv_d2_ptr(pt_ptr, ws_ptr, xt_ptr, xs_ptr, y_ptr, cols, rows, ng)
19
+ """
20
+ import ctypes
21
+ import os
22
+ import subprocess
23
+ import sys
24
+
25
+ # Find the library
26
+ _LIB_NAMES = ['libtrit_gemv.so', 'libtrit_gemv.dll', 'trit_gemv.so']
27
+ _SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
28
+
29
+
30
+ def _find_lib():
31
+ for name in _LIB_NAMES:
32
+ path = os.path.join(_SCRIPT_DIR, name)
33
+ if os.path.exists(path):
34
+ return path
35
+ return None
36
+
37
+
38
+ def _build_lib():
39
+ """Auto-compile if not found."""
40
+ build_script = os.path.join(_SCRIPT_DIR, 'build.sh')
41
+ if os.path.exists(build_script):
42
+ print("Building libtrit_gemv.so...", flush=True)
43
+ subprocess.run(['bash', build_script], cwd=_SCRIPT_DIR, check=True)
44
+ else:
45
+ # Inline build
46
+ cu_file = os.path.join(_SCRIPT_DIR, 'trit_gemv_standalone.cu')
47
+ out_file = os.path.join(_SCRIPT_DIR, 'libtrit_gemv.so')
48
+ if not os.path.exists(cu_file):
49
+ raise FileNotFoundError(f"Cannot find {cu_file}")
50
+
51
+ # Detect GPU architecture
52
+ try:
53
+ import torch
54
+ cc = torch.cuda.get_device_capability(0)
55
+ arch = f"compute_{cc[0]}{cc[1]}"
56
+ sm = f"sm_{cc[0]}{cc[1]}"
57
+ gencode = f"-gencode=arch={arch},code={sm}"
58
+ except Exception:
59
+ # Default to common architectures
60
+ gencode = " ".join([
61
+ f"-gencode=arch=compute_{a},code=sm_{a}"
62
+ for a in ["70", "75", "80", "86", "89", "90"]
63
+ ])
64
+
65
+ cmd = f"nvcc -O3 --use_fast_math -shared -Xcompiler -fPIC {gencode} -o {out_file} {cu_file}"
66
+ print(f"Compiling: {cmd}", flush=True)
67
+ subprocess.run(cmd, shell=True, check=True)
68
+
69
+ return _find_lib()
70
+
71
+
72
+ class TritGEMV:
73
+ """Framework-agnostic trit GEMV kernel."""
74
+
75
+ def __init__(self, lib_path=None):
76
+ if lib_path is None:
77
+ lib_path = _find_lib()
78
+ if lib_path is None:
79
+ lib_path = _build_lib()
80
+ if lib_path is None:
81
+ raise RuntimeError("Cannot find or build libtrit_gemv.so")
82
+
83
+ self._lib = ctypes.CDLL(lib_path)
84
+
85
+ # Set up function signatures
86
+ # d2 dp4a (champion)
87
+ self._lib.trit_gemv_d2_dp4a.argtypes = [
88
+ ctypes.c_void_p, # pt (int32*)
89
+ ctypes.c_void_p, # ws (float*)
90
+ ctypes.c_void_p, # xt (int32*)
91
+ ctypes.c_void_p, # xs (float*)
92
+ ctypes.c_void_p, # y (float*)
93
+ ctypes.c_int, # cols
94
+ ctypes.c_int, # rows
95
+ ctypes.c_int, # num_groups
96
+ ctypes.c_int, # use_l2_persist
97
+ ]
98
+ self._lib.trit_gemv_d2_dp4a.restype = None
99
+
100
+ # d3 native trit
101
+ self._lib.trit_gemv_d3_native.argtypes = [
102
+ ctypes.c_void_p, # pt
103
+ ctypes.c_void_p, # sc
104
+ ctypes.c_void_p, # x
105
+ ctypes.c_void_p, # y
106
+ ctypes.c_int, # cols
107
+ ctypes.c_int, # rows
108
+ ctypes.c_int, # depth
109
+ ]
110
+ self._lib.trit_gemv_d3_native.restype = None
111
+
112
+ # d3 int8 dp4a (no decode, DRAM-bound path)
113
+ self._lib.trit_gemv_d3_int8_dp4a.argtypes = [
114
+ ctypes.c_void_p, # wt (int32*)
115
+ ctypes.c_void_p, # ws (float*)
116
+ ctypes.c_void_p, # xt (int32*)
117
+ ctypes.c_void_p, # xs (float*)
118
+ ctypes.c_void_p, # y (float*)
119
+ ctypes.c_int, # cols
120
+ ctypes.c_int, # rows
121
+ ctypes.c_int, # num_groups
122
+ ctypes.c_int, # use_l2_persist
123
+ ]
124
+ self._lib.trit_gemv_d3_int8_dp4a.restype = None
125
+
126
+ # Utility
127
+ self._lib.get_l2_cache_bytes.restype = ctypes.c_int
128
+ self._lib.cuda_sync.restype = None
129
+
130
+ buf = ctypes.create_string_buffer(256)
131
+ self._lib.get_gpu_name(buf, 256)
132
+ self.gpu_name = buf.value.decode()
133
+ self.l2_bytes = self._lib.get_l2_cache_bytes()
134
+
135
+ def sync(self):
136
+ self._lib.cuda_sync()
137
+
138
+ def _get_ptr(self, tensor):
139
+ """Extract GPU pointer from any framework's tensor."""
140
+ if hasattr(tensor, 'data_ptr'):
141
+ # PyTorch
142
+ return tensor.data_ptr()
143
+ elif hasattr(tensor, '__cuda_array_interface__'):
144
+ # CuPy, JAX, Numba
145
+ return tensor.__cuda_array_interface__['data'][0]
146
+ elif isinstance(tensor, int):
147
+ # Raw pointer
148
+ return tensor
149
+ else:
150
+ raise TypeError(f"Cannot extract GPU pointer from {type(tensor)}")
151
+
152
+ def gemv_d2(self, pt, ws, xt, xs, y, cols, rows, num_groups, l2_persist=True):
153
+ """D2 GEMV with int4 packing + dp4a.
154
+
155
+ Args:
156
+ pt: int32 tensor [rows * num_groups * 8] — int4 packed weights
157
+ ws: float32 tensor [rows * num_groups] — weight scales
158
+ xt: int32 tensor [num_groups * 16] — int8 packed activations
159
+ xs: float32 tensor [num_groups] — activation scales
160
+ y: float32 tensor [rows] — output (written in-place)
161
+ cols: input dimension (K)
162
+ rows: output dimension (M)
163
+ num_groups: K // 64
164
+ l2_persist: enable L2 cache persistence (default True)
165
+ """
166
+ self._lib.trit_gemv_d2_dp4a(
167
+ self._get_ptr(pt), self._get_ptr(ws),
168
+ self._get_ptr(xt), self._get_ptr(xs),
169
+ self._get_ptr(y), cols, rows, num_groups,
170
+ 1 if l2_persist else 0,
171
+ )
172
+
173
+ def gemv_adaptive(self, pt_int4, ws, xt, xs, y, cols, rows, num_groups,
174
+ pt_int8=None):
175
+ """Hardware-aware GEMV: auto-selects best kernel based on L2 cache.
176
+
177
+ If the int4 weight data fits in L2 → uses d2 int4 + dp4a (5x FP16)
178
+ If not → uses pre-expanded int8 + dp4a (2x FP16, no decode overhead)
179
+
180
+ Args:
181
+ pt_int4: int32 tensor — int4 packed weights (always stored, compact)
182
+ ws: weight scales
183
+ xt, xs: quantized activations
184
+ y: output
185
+ pt_int8: optional pre-expanded int8 weights for DRAM path.
186
+ If None and needed, expanded on-the-fly (one-time cost).
187
+ """
188
+ weight_bytes = rows * num_groups * 8 * 4 # int4: 8 words per group
189
+ l2_margin = self.l2_bytes * 0.75 # leave 25% for x, scales, other data
190
+
191
+ if weight_bytes < l2_margin:
192
+ # Fits in L2 → use compact int4, decode inline at L2 speed
193
+ self._lib.trit_gemv_d2_dp4a(
194
+ self._get_ptr(pt_int4), self._get_ptr(ws),
195
+ self._get_ptr(xt), self._get_ptr(xs),
196
+ self._get_ptr(y), cols, rows, num_groups, 1)
197
+ else:
198
+ # Doesn't fit L2 → use int8 for zero-decode DRAM speed
199
+ if pt_int8 is None:
200
+ raise ValueError(
201
+ f"Layer ({weight_bytes/1e6:.0f} MB) exceeds L2 ({self.l2_bytes/1e6:.0f} MB). "
202
+ f"Provide pre-expanded pt_int8 for DRAM path. "
203
+ f"Use TritGEMV.expand_int4_to_int8(pt_int4) at model load time."
204
+ )
205
+ self._lib.trit_gemv_d3_int8_dp4a(
206
+ self._get_ptr(pt_int8), self._get_ptr(ws),
207
+ self._get_ptr(xt), self._get_ptr(xs),
208
+ self._get_ptr(y), cols, rows, num_groups, 0)
209
+
210
+ @staticmethod
211
+ def expand_int4_to_int8(pt_int4, device='cuda'):
212
+ """Pre-expand int4 packed weights to int8 for DRAM-bound layers.
213
+
214
+ Called once at model load. Uses 2x more VRAM but eliminates decode overhead.
215
+ int4: 8 words per group → int8: 16 words per group
216
+
217
+ Args:
218
+ pt_int4: int32 tensor [n_groups * 8] — int4 packed
219
+ Returns:
220
+ int32 tensor [n_groups * 16] — int8 packed (dp4a compatible)
221
+ """
222
+ import torch
223
+ n_words = pt_int4.shape[0]
224
+ n_groups = n_words // 8
225
+
226
+ # Each int4 word has 8 nibbles → 8 int8 values → 2 int8x4 words
227
+ pt_int8 = torch.zeros(n_groups * 16, dtype=torch.int32, device=device)
228
+
229
+ # Expand on GPU (vectorized)
230
+ for g in range(n_groups):
231
+ for w in range(8):
232
+ word = pt_int4[g * 8 + w].item()
233
+ for nib in range(8):
234
+ val = (word >> (nib * 4)) & 0xF
235
+ if val & 0x8:
236
+ val = val | 0xFFFFFFF0 # sign extend
237
+ val = val & 0xFF
238
+ out_col = w * 8 + nib
239
+ out_word = out_col // 4
240
+ out_byte = out_col % 4
241
+ pt_int8[g * 16 + out_word] |= (val << (out_byte * 8))
242
+
243
+ return pt_int8
244
+
245
+ def gemv_d3(self, pt, sc, x, y, cols, rows, depth=3):
246
+ """D3 GEMV with native trit packing.
247
+
248
+ Args:
249
+ pt: int32 tensor [rows * ng * 13] — trit packed weights
250
+ sc: float32 tensor [rows * ng] — scales
251
+ x: float32 tensor [cols] — activations
252
+ y: float32 tensor [rows] — output
253
+ """
254
+ self._lib.trit_gemv_d3_native(
255
+ self._get_ptr(pt), self._get_ptr(sc),
256
+ self._get_ptr(x), self._get_ptr(y),
257
+ cols, rows, depth,
258
+ )
259
+
260
+ def gemv_d3_int8(self, wt, ws, xt, xs, y, cols, rows, num_groups, l2_persist=True):
261
+ """D3 GEMV with int8 level packing + dp4a (same quality as d3, dp4a speed).
262
+
263
+ Args:
264
+ wt: int32 tensor [rows * num_groups * 16] — int8 packed levels
265
+ ws: float32 tensor [rows * num_groups] — weight scales
266
+ xt: int32 tensor [num_groups * 16] — int8 packed activations
267
+ xs: float32 tensor [num_groups * 16] — per-word x scales
268
+ y: float32 tensor [rows] — output
269
+ """
270
+ if not hasattr(self._lib, 'trit_gemv_d3_int8_dp4a'):
271
+ raise RuntimeError("d3 int8 not in this build — rebuild libtrit_gemv.so")
272
+ self._lib.trit_gemv_d3_int8_dp4a(
273
+ self._get_ptr(wt), self._get_ptr(ws),
274
+ self._get_ptr(xt), self._get_ptr(xs),
275
+ self._get_ptr(y), cols, rows, num_groups,
276
+ 1 if l2_persist else 0,
277
+ )
278
+
279
+ def __repr__(self):
280
+ return f"TritGEMV(gpu='{self.gpu_name}', l2={self.l2_bytes/1e6:.0f}MB)"
trit_gemv_standalone.cu ADDED
@@ -0,0 +1,598 @@
1
+ /*
2
+ * Standalone trit GEMV kernel — no PyTorch dependency.
3
+ * Compiles with nvcc to a shared library (.so/.dll).
4
+ * Called from Python via ctypes, or from C/C++ directly.
5
+ *
6
+ * Compile:
7
+ * nvcc -O3 --use_fast_math -shared -Xcompiler -fPIC \
8
+ * -gencode=arch=compute_70,code=sm_70 \
9
+ * -gencode=arch=compute_75,code=sm_75 \
10
+ * -gencode=arch=compute_80,code=sm_80 \
11
+ * -gencode=arch=compute_86,code=sm_86 \
12
+ * -gencode=arch=compute_89,code=sm_89 \
13
+ * -gencode=arch=compute_90,code=sm_90 \
14
+ * -gencode=arch=compute_100,code=sm_100 \
15
+ * -gencode=arch=compute_120,code=sm_120 \
16
+ * -o libtrit_gemv.so trit_gemv_standalone.cu
17
+ *
18
+ * Supports: Volta(V100), Turing(2080), Ampere(3090/A100), Ada(4080/4090),
19
+ * Hopper(H100), Blackwell(5070/5090) — all in one binary.
20
+ *
21
+ * API: C functions with extern "C" — callable from any language.
22
+ */
23
+
24
+ #include <cuda_runtime.h>
25
+ #include <stdint.h>
+ #include <string.h>   // strncpy, memset
26
+
27
+ #define GROUP_SIZE 64
28
+ #define WARP_SIZE 32
29
+ #define TRIT_POS 1
30
+ #define TRIT_NEG 2
31
+
32
+ // Forward declarations
33
+ static void set_l2_persist(void* ptr, size_t bytes);
34
+ static void clear_l2_persist();
35
+
36
+ // ============================================================
37
+ // V27: D2 int4-packed + dp4a + L2 persist (champion kernel)
38
+ // ============================================================
39
+
40
+ // ============================================================
41
+ // V28: Branchless interleaved nibble decode (7 instructions for 8 weights)
42
+ //
43
+ // The trick: extract even nibbles (0,2,4,6) and odd nibbles (1,3,5,7)
44
+ // as separate byte vectors using mask + shift. Sign-extend all 4 bytes
45
+ // simultaneously with XOR + SUB (zero branches).
46
+ //
47
+ // x activations are pre-interleaved to match: x_evens has values at
48
+ // positions 0,2,4,6 and x_odds has 1,3,5,7. Pre-interleave is done
49
+ // once at activation quantization time (negligible cost).
50
+ //
51
+ // Instructions per 8 weights:
52
+ // v27: 32 (loop with branches)
53
+ // v28: 14 (7 expand + 2 dp4a + 3 load + 2 scale)
54
+ // Balance BW shift: 3.4 → 5.6 TB/s on A100 → crosses into memory-bound!
55
+ // ============================================================
56
+
57
+ #define V28_RPB 16
58
+ #define V28_WPG 8
59
+ #define V28_BS (V28_RPB * WARP_SIZE)
60
+
61
+ __global__ void k_v28(
62
+ const uint32_t* __restrict__ pt, // int4 packed: 8 weights per uint32
63
+ const float* __restrict__ ws, // weight scales [rows * ng]
64
+ const uint32_t* __restrict__ xt_e, // x int8 EVEN positions [ng * 8]
65
+ const uint32_t* __restrict__ xt_o, // x int8 ODD positions [ng * 8]
66
+ const float* __restrict__ xs, // x scales [ng]
67
+ float* __restrict__ y,
68
+ int cols, int rows, int num_groups
69
+ ) {
70
+ int wid = threadIdx.x / WARP_SIZE;
71
+ int lane = threadIdx.x % WARP_SIZE;
72
+ int row = blockIdx.x * V28_RPB + wid;
73
+ if (row >= rows) return;
74
+
75
+ const uint32_t* row_w = &pt[row * num_groups * V28_WPG];
76
+ const float* row_ws = &ws[row * num_groups];
77
+
78
+ float acc = 0.0f;
79
+ int total_words = num_groups * V28_WPG;
80
+
81
+ for (int base = 0; base < total_words; base += WARP_SIZE) {
82
+ int w = base + lane;
83
+ if (w < total_words) {
84
+ // COALESCED load
85
+ uint32_t word = __ldg(&row_w[w]);
86
+ int g = w >> 3; // group index (shift, 1 cycle)
87
+ int word_in_group = w & 7; // word within group (mask, 1 cycle)
88
+
89
+ // === BRANCHLESS INT4→INT8 EXPANSION (7 instructions) ===
90
+ // Extract even nibbles (weights 0,2,4,6) into bytes
91
+ uint32_t evens = word & 0x0F0F0F0F; // AND (1 op)
92
+ evens = (evens ^ 0x08080808) - 0x08080808; // XOR+SUB (2 ops)
93
+
94
+ // Extract odd nibbles (weights 1,3,5,7) into bytes
95
+ uint32_t odds = (word >> 4) & 0x0F0F0F0F; // SHR+AND (2 ops)
96
+ odds = (odds ^ 0x08080808) - 0x08080808; // XOR+SUB (2 ops)
97
+ // Total: 7 instructions for 8 sign-extended int8 values
98
+
99
+ // dp4a against pre-interleaved x
100
+ // x_evens[g*8 + word_in_group] has activations at even positions
101
+ // x_odds[g*8 + word_in_group] has activations at odd positions
102
+ int x_idx = g * 8 + word_in_group;
103
+ uint32_t xe = __ldg(&xt_e[x_idx]);
104
+ uint32_t xo = __ldg(&xt_o[x_idx]);
105
+
106
+ int dp_e = __dp4a((int)evens, (int)xe, 0);
107
+ int dp_o = __dp4a((int)odds, (int)xo, 0);
108
+
109
+ float combined_scale = __ldg(&row_ws[g]) * __ldg(&xs[g]);
110
+ acc += (float)(dp_e + dp_o) * combined_scale;
111
+ }
112
+ }
113
+
114
+ #pragma unroll
115
+ for (int o = 16; o > 0; o >>= 1)
116
+ acc += __shfl_down_sync(0xFFFFFFFF, acc, o);
117
+
118
+ if (lane == 0) y[row] = acc;
119
+ }
120
+
121
+ // ============================================================
122
+ // V29: BIAS TRICK — zero sign extension, unsigned weights + correction
123
+ //
124
+ // Store weights as (level + 4) → range 0-8 (unsigned).
125
+ // Zero-extend nibbles: AND only, no XOR, no SUB.
126
+ // dp4a gives biased result. Subtract precomputed correction.
127
+ //
128
+ // Decode: 3 instructions (AND, SHR, AND) vs v28's 7
129
+ // Correction: 1 SUB + 1 LD per word (precomputed x_bias)
130
+ // Net: 10 instructions per word vs v28's 14
131
+ // ============================================================
132
+
133
+ #define V29_RPB 16
134
+ #define V29_WPG 8
135
+ #define V29_BS (V29_RPB * WARP_SIZE)
136
+
137
+ __global__ void k_v29(
138
+ const uint32_t* __restrict__ pt, // UNSIGNED int4: (level+4) packed, 0-8 per nibble
139
+ const float* __restrict__ ws, // weight scales [rows * ng]
140
+ const uint32_t* __restrict__ xt_e, // x int8 EVEN [ng * 8]
141
+ const uint32_t* __restrict__ xt_o, // x int8 ODD [ng * 8]
142
+ const int* __restrict__ x_bias, // precomputed 4×(sum of 8 x values) per word position [ng * 8]
143
+ const float* __restrict__ xs, // x scales [ng]
144
+ float* __restrict__ y,
145
+ int cols, int rows, int num_groups
146
+ ) {
147
+ int wid = threadIdx.x / WARP_SIZE;
148
+ int lane = threadIdx.x % WARP_SIZE;
149
+ int row = blockIdx.x * V29_RPB + wid;
150
+ if (row >= rows) return;
151
+
152
+ const uint32_t* row_w = &pt[row * num_groups * V29_WPG];
153
+ const float* row_ws = &ws[row * num_groups];
154
+
155
+ float acc = 0.0f;
156
+ int total_words = num_groups * V29_WPG;
157
+
158
+ for (int base = 0; base < total_words; base += WARP_SIZE) {
159
+ int w = base + lane;
160
+ if (w < total_words) {
161
+ uint32_t word = __ldg(&row_w[w]);
162
+ int g = w >> 3;
163
+ int wig = w & 7;
164
+
165
+ // BIAS DECODE: 3 instructions total (no XOR, no SUB)
166
+ uint32_t evens = word & 0x0F0F0F0F; // AND
167
+ uint32_t odds = (word >> 4) & 0x0F0F0F0F; // SHR + AND
168
+ // Values 0-8 are valid positive int8 — no sign extension needed
169
+
170
+ int x_idx = g * 8 + wig;
171
+ uint32_t xe = __ldg(&xt_e[x_idx]);
172
+ uint32_t xo = __ldg(&xt_o[x_idx]);
173
+
174
+ // dp4a: biased result (includes +4 per weight)
175
+ int dp = __dp4a((int)evens, (int)xe, 0)
176
+ + __dp4a((int)odds, (int)xo, 0);
177
+
178
+ // Subtract precomputed bias: 4 × sum of 8 x values for this word
179
+ int bias = __ldg(&x_bias[x_idx]);
180
+ dp -= bias;
181
+
182
+ float combined_scale = __ldg(&row_ws[g]) * __ldg(&xs[g]);
183
+ acc += (float)dp * combined_scale;
184
+ }
185
+ }
186
+
187
+ #pragma unroll
188
+ for (int o = 16; o > 0; o >>= 1)
189
+ acc += __shfl_down_sync(0xFFFFFFFF, acc, o);
190
+
191
+ if (lane == 0) y[row] = acc;
192
+ }
193
+
194
+ // v29 wrapper moved to extern "C" block below
195
+
196
+ // Keep v27 as fallback (doesn't need interleaved x)
197
+ __device__ __forceinline__ uint32_t extract_int4x4_to_int8x4(uint32_t word, int start) {
198
+ uint32_t result = 0;
199
+ #pragma unroll
200
+ for (int i = 0; i < 4; i++) {
201
+ int shift = (start + i) * 4;
202
+ int nibble = (word >> shift) & 0xF;
203
+ int val = (nibble & 0x8) ? (nibble | 0xFFFFFFF0) : nibble;
204
+ result |= ((uint32_t)(val & 0xFF)) << (i * 8);
205
+ }
206
+ return result;
207
+ }
208
+
209
+ #define V27_RPB 16
210
+ #define V27_WPG 8
211
+ #define V27_BS (V27_RPB * WARP_SIZE)
212
+
213
+ __global__ void k_v27(
214
+ const uint32_t* __restrict__ pt,
215
+ const float* __restrict__ ws,
216
+ const uint32_t* __restrict__ xt,
217
+ const float* __restrict__ xs,
218
+ float* __restrict__ y,
219
+ int cols, int rows, int num_groups
220
+ ) {
221
+ int wid = threadIdx.x / WARP_SIZE;
222
+ int lane = threadIdx.x % WARP_SIZE;
223
+ int row = blockIdx.x * V27_RPB + wid;
224
+ if (row >= rows) return;
225
+
226
+ const uint32_t* row_w = &pt[row * num_groups * V27_WPG];
227
+ const float* row_ws = &ws[row * num_groups];
228
+
229
+ float acc = 0.0f;
230
+ int total_words = num_groups * V27_WPG;
231
+
232
+ for (int base = 0; base < total_words; base += WARP_SIZE) {
233
+ int w = base + lane;
234
+ if (w < total_words) {
235
+ uint32_t word = __ldg(&row_w[w]);
236
+ int g = w >> 3;
237
+ int word_in_group = w & 7;
238
+
239
+ uint32_t lo = extract_int4x4_to_int8x4(word, 0);
240
+ uint32_t hi = extract_int4x4_to_int8x4(word, 4);
241
+
242
+ int x_base = g * 16 + word_in_group * 2;
243
+ uint32_t x_lo = __ldg(&xt[x_base]);
244
+ uint32_t x_hi = __ldg(&xt[x_base + 1]);
245
+
246
+ int dp_lo = __dp4a((int)lo, (int)x_lo, 0);
247
+ int dp_hi = __dp4a((int)hi, (int)x_hi, 0);
248
+
249
+ float combined_scale = __ldg(&row_ws[g]) * __ldg(&xs[g]);
250
+ acc += (float)(dp_lo + dp_hi) * combined_scale;
251
+ }
252
+ }
253
+
254
+ #pragma unroll
255
+ for (int o = 16; o > 0; o >>= 1)
256
+ acc += __shfl_down_sync(0xFFFFFFFF, acc, o);
257
+
258
+ if (lane == 0) y[row] = acc;
259
+ }
260
+
261
+ // ============================================================
262
+ // V9-style: trit-packed d3 (for models stored in trit format)
263
+ // ============================================================
264
+
265
+ #define V9R 4
266
+ #define V9W 2
267
+ #define V9BS (V9R * V9W * WARP_SIZE)
268
+
269
+ __device__ __forceinline__ float mac_wide_d3(
270
+ const uint32_t* __restrict__ p, const float* __restrict__ x, int tid
271
+ ) {
272
+ float acc = 0.0f;
273
+ #pragma unroll
274
+ for (int i = 0; i < 4; i++) {
275
+ int idx = tid * 4 + i;
276
+ int w = idx / 5, pos = idx % 5;
277
+ uint32_t bits = (__ldg(&p[w]) >> (pos * 6)) & 0x3F;
278
+ int t0 = bits & 3, t1 = (bits >> 2) & 3, t2 = (bits >> 4) & 3;
279
+ int lv = ((t2==TRIT_POS)-(t2==TRIT_NEG))*9 + ((t1==TRIT_POS)-(t1==TRIT_NEG))*3
280
+ + ((t0==TRIT_POS)-(t0==TRIT_NEG));
281
+ acc += lv * __ldg(&x[idx]);
282
+ }
283
+ return acc;
284
+ }
285
+
286
+ __global__ void k_v9(
287
+ const uint32_t* __restrict__ pt, const float* __restrict__ sc,
288
+ const float* __restrict__ x, float* __restrict__ y,
289
+ int in_f, int out_f, int depth
290
+ ) {
291
+ __shared__ float parts[V9R * V9W];
292
+ int base = blockIdx.x * V9R;
293
+ int wid = threadIdx.x / WARP_SIZE, lane = threadIdx.x % WARP_SIZE;
294
+ int lr = wid / V9W, rw = wid % V9W, row = base + lr;
295
+ int ng = in_f / GROUP_SIZE;
296
+ const int w = 13;
297
+
298
+ int half = lane / 16;
299
+ int tid_in_group = lane % 16;
300
+
301
+ float partial = 0.0f;
302
+ if (row < out_f) {
303
+ for (int g_pair = rw; g_pair < (ng + 1) / 2; g_pair += V9W) {
304
+ int g = g_pair * 2 + half;
305
+ if (g < ng) {
306
+ float ga = mac_wide_d3(&pt[(row * ng + g) * w],
307
+ &x[g * GROUP_SIZE], tid_in_group);
308
+ unsigned mask = half ? 0xFFFF0000u : 0x0000FFFFu;
309
+ #pragma unroll
310
+ for (int o = 8; o > 0; o >>= 1)
311
+ ga += __shfl_down_sync(mask, ga, o);
312
+ if (tid_in_group == 0)
313
+ partial += ga * __ldg(&sc[row * ng + g]);
314
+ }
315
+ }
316
+ }
317
+
318
+ float my_partial = (lane == 0 || lane == 16) ? partial : 0.0f;
319
+ my_partial += __shfl_xor_sync(0xFFFFFFFF, my_partial, 16);
320
+ if (lane == 0) parts[wid] = my_partial;
321
+ __syncthreads();
322
+ if (lane == 0 && rw == 0 && row < out_f) {
323
+ float s = 0;
324
+ for (int i = 0; i < V9W; i++) s += parts[lr * V9W + i];
325
+ y[row] = s;
326
+ }
327
+ }
328
+
329
+ // ============================================================
330
+ // L2 persistence helpers
331
+ // ============================================================
332
+
333
+ static void set_l2_persist(void* ptr, size_t bytes) {
334
+ cudaStreamAttrValue attr;
335
+ attr.accessPolicyWindow.base_ptr = ptr;
336
+ attr.accessPolicyWindow.num_bytes = bytes;
337
+ attr.accessPolicyWindow.hitRatio = 1.0f;
338
+ attr.accessPolicyWindow.hitProp = cudaAccessPropertyPersisting;
339
+ attr.accessPolicyWindow.missProp = cudaAccessPropertyStreaming;
340
+ cudaStreamSetAttribute(0, cudaStreamAttributeAccessPolicyWindow, &attr);
341
+ }
342
+
343
+ static void clear_l2_persist() {
344
+ cudaStreamAttrValue attr;
345
+ memset(&attr, 0, sizeof(attr));
346
+ cudaStreamSetAttribute(0, cudaStreamAttributeAccessPolicyWindow, &attr);
347
+ }
348
+
349
+ // ============================================================
350
+ // C API — callable from any language via dlopen/ctypes/FFI
351
+ // ============================================================
352
+
353
+ extern "C" {
354
+
355
+ // v27: d2 int4-packed + dp4a (champion for GPU)
356
+ // pt: [rows * ng * 8] int32 (int4 packed weights)
357
+ // ws: [rows * ng] float32 (weight scales)
358
+ // xt: [ng * 16] int32 (int8 packed activations)
359
+ // xs: [ng] float32 (activation scales)
360
+ // y: [rows] float32 (output)
361
+ void trit_gemv_d2_dp4a(
362
+ const int32_t* pt, const float* ws,
363
+ const int32_t* xt, const float* xs,
364
+ float* y, int cols, int rows, int num_groups,
365
+ int use_l2_persist
366
+ ) {
367
+ if (use_l2_persist) {
368
+ set_l2_persist((void*)pt, (size_t)rows * num_groups * 8 * sizeof(int32_t));
369
+ }
370
+ k_v27<<<(rows + V27_RPB - 1) / V27_RPB, V27_BS>>>(
371
+ (const uint32_t*)pt, ws, (const uint32_t*)xt, xs, y, cols, rows, num_groups);
372
+ if (use_l2_persist) {
373
+ clear_l2_persist();
374
+ }
375
+ }
376
+
377
+ // v9: trit-packed d3 (for native trit format)
378
+ // pt: [rows * ng * 13] int32 (trit packed weights)
379
+ // sc: [rows * ng] float32 (scales)
380
+ // x: [cols] float32 (activations)
381
+ // y: [rows] float32 (output)
382
+ void trit_gemv_d3_native(
383
+ const int32_t* pt, const float* sc,
384
+ const float* x, float* y,
385
+ int cols, int rows, int depth
386
+ ) {
387
+ k_v9<<<(rows + V9R - 1) / V9R, V9BS>>>(
388
+ (const uint32_t*)pt, sc, x, y, cols, rows, depth);
389
+ }
390
+
391
+ // v29: d2 unsigned int4 + bias trick (no sign extension)
392
+ void trit_gemv_d2_bias(
393
+ const int32_t* pt, const float* ws,
394
+ const int32_t* xt_e, const int32_t* xt_o,
395
+ const int32_t* x_bias, const float* xs,
396
+ float* y, int cols, int rows, int num_groups,
397
+ int use_l2_persist
398
+ ) {
399
+ if (use_l2_persist) {
400
+ set_l2_persist((void*)pt, (size_t)rows * num_groups * 8 * sizeof(int32_t));
401
+ }
402
+ k_v29<<<(rows + V29_RPB - 1) / V29_RPB, V29_BS>>>(
403
+ (const uint32_t*)pt, ws,
404
+ (const uint32_t*)xt_e, (const uint32_t*)xt_o,
405
+ (const int*)x_bias, xs,
406
+ y, cols, rows, num_groups);
407
+ if (use_l2_persist) {
408
+ clear_l2_persist();
409
+ }
410
+ }
411
+
412
+ // v28: d2 int4 + branchless interleaved decode + dp4a
413
+ // xt_e/xt_o: pre-interleaved x (even/odd nibble positions)
414
+ // Each has ng*8 uint32 words (4 int8 values per word, 8 words per group)
415
+ void trit_gemv_d2_fast(
416
+ const int32_t* pt, const float* ws,
417
+ const int32_t* xt_e, const int32_t* xt_o, const float* xs,
418
+ float* y, int cols, int rows, int num_groups,
419
+ int use_l2_persist
420
+ ) {
421
+ if (use_l2_persist) {
422
+ set_l2_persist((void*)pt, (size_t)rows * num_groups * 8 * sizeof(int32_t));
423
+ }
424
+ k_v28<<<(rows + V28_RPB - 1) / V28_RPB, V28_BS>>>(
425
+ (const uint32_t*)pt, ws,
426
+ (const uint32_t*)xt_e, (const uint32_t*)xt_o, xs,
427
+ y, cols, rows, num_groups);
428
+ if (use_l2_persist) {
429
+ clear_l2_persist();
430
+ }
431
+ }
432
+
433
+ // v21f: d3 int8-packed + dp4a (same format as v21f in wrapper)
434
+ // wt: [rows * ng * 16] int32 (int8 packed weight levels)
435
+ // ws: [rows * ng] float32 (weight scales)
436
+ // xt: [ng * 16] int32 (int8 packed activations)
437
+ // xs: [ng] float32 (activation scales)
438
+ #define V21F_RPB 4
439
+ #define V21F_BS (V21F_RPB * WARP_SIZE)
440
+
441
+ __global__ void k_v21f_standalone(
442
+ const uint32_t* __restrict__ wt,
443
+ const float* __restrict__ ws,
444
+ const uint32_t* __restrict__ xt,
445
+ const float* __restrict__ xs,
446
+ float* __restrict__ y,
447
+ int cols, int rows, int num_groups
448
+ ) {
449
+ int wid = threadIdx.x / WARP_SIZE;
450
+ int lane = threadIdx.x % WARP_SIZE;
451
+ int row = blockIdx.x * V21F_RPB + wid;
452
+ if (row >= rows) return;
453
+
454
+ const uint32_t* row_w = &wt[row * num_groups * 16];
455
+ const float* row_ws = &ws[row * num_groups];
456
+ float acc = 0.0f;
457
+ int total_words = num_groups * 16;
458
+
459
+ for (int base = 0; base < total_words; base += WARP_SIZE) {
460
+ int w = base + lane;
461
+ if (w < total_words) {
462
+ uint32_t w_word = __ldg(&row_w[w]);
463
+ uint32_t x_word = __ldg(&xt[w]);
464
+ int dp = __dp4a((int)w_word, (int)x_word, 0);
465
+ int g = w >> 4; // 16 words per group
466
+ acc += (float)dp * __ldg(&row_ws[g]) * __ldg(&xs[g]);
467
+ }
468
+ }
469
+
470
+ #pragma unroll
471
+ for (int o = 16; o > 0; o >>= 1)
472
+ acc += __shfl_down_sync(0xFFFFFFFF, acc, o);
473
+
474
+ if (lane == 0) y[row] = acc;
475
+ }
476
+
477
+ // ============================================================
478
+ // D3 HARDENED: int8 dp4a, 16 RPB, L2 persist, deferred reduction
479
+ //
480
+ // d3 levels: -13 to +13 (27 values), stored as int8 (1 byte each)
481
+ // 16 words per group, 4 int8 values per word = 64 values/group
482
+ // Division by 16 = shift (no div-by-13 problem!)
483
+ //
484
+ // This is the SIMPLEST kernel — no decode at all.
485
+ // The int8 values go DIRECTLY into dp4a.
486
+ // Pure memory-bound on every GPU.
487
+ // ============================================================
488
+
489
+ #define D3H_RPB 16
490
+ #define D3H_WPG 16 // 16 uint32 words per group (4 int8 each = 64 values)
491
+ #define D3H_BS (D3H_RPB * WARP_SIZE)
492
+
493
+ __global__ void k_d3_hardened(
494
+ const uint32_t* __restrict__ wt, // int8 packed: 16 words per group
495
+ const float* __restrict__ ws, // weight scales [rows * ng]
496
+ const uint32_t* __restrict__ xt, // x int8 packed: 16 words per group
497
+ const float* __restrict__ xs, // x scales [ng]
498
+ float* __restrict__ y,
499
+ int cols, int rows, int num_groups
500
+ ) {
501
+ int wid = threadIdx.x / WARP_SIZE;
502
+ int lane = threadIdx.x % WARP_SIZE;
503
+ int row = blockIdx.x * D3H_RPB + wid;
504
+ if (row >= rows) return;
505
+
506
+ const uint32_t* row_w = &wt[row * num_groups * D3H_WPG];
507
+ const float* row_ws = &ws[row * num_groups];
508
+
509
+ float acc = 0.0f;
510
+ int total_words = num_groups * D3H_WPG;
511
+
512
+ for (int base = 0; base < total_words; base += WARP_SIZE) {
513
+ int w = base + lane;
514
+ if (w < total_words) {
515
+ // COALESCED load — 32 threads × 4 bytes = 128 bytes
516
+ uint32_t w_word = __ldg(&row_w[w]);
517
+ uint32_t x_word = __ldg(&xt[w]);
518
+
519
+ // dp4a: 4× int8 multiply-accumulate — ZERO decode
520
+ int dp = __dp4a((int)w_word, (int)x_word, 0);
521
+
522
+ // Group index: SHIFT (16 is power of 2!)
523
+ int g = w >> 4;
524
+
525
+ // Deferred: accumulate per-thread, ONE reduction at end
526
+ acc += (float)dp * __ldg(&row_ws[g]) * __ldg(&xs[g]);
527
+ }
528
+ }
529
+
530
+ // ONE warp reduction
531
+ #pragma unroll
532
+ for (int o = 16; o > 0; o >>= 1)
533
+ acc += __shfl_down_sync(0xFFFFFFFF, acc, o);
534
+
535
+ if (lane == 0) y[row] = acc;
536
+ }
537
+
538
+ // d3 hardened: uses k_d3_hardened (16 RPB, deferred reduction)
539
+ void trit_gemv_d3_int8_dp4a(
540
+ const int32_t* wt, const float* ws,
541
+ const int32_t* xt, const float* xs,
542
+ float* y, int cols, int rows, int num_groups,
543
+ int use_l2_persist
544
+ ) {
545
+ if (use_l2_persist) {
546
+ set_l2_persist((void*)wt, (size_t)rows * num_groups * 16 * sizeof(int32_t));
547
+ }
548
+ k_d3_hardened<<<(rows + D3H_RPB - 1) / D3H_RPB, D3H_BS>>>(
549
+ (const uint32_t*)wt, ws, (const uint32_t*)xt, xs, y, cols, rows, num_groups);
550
+ if (use_l2_persist) {
551
+ clear_l2_persist();
552
+ }
553
+ }
554
+
555
+ // Run the same layer N times back-to-back to measure pipeline / L2 reuse benefit
556
+ void trit_gemv_pipeline_bench(
557
+ const int32_t* pt, const float* ws,
558
+ const int32_t* xt_e, const int32_t* xt_o, const float* xs,
559
+ float* y, int cols, int rows, int num_groups,
560
+ int n_repeats, int use_l2_persist
561
+ ) {
562
+ if (use_l2_persist) {
563
+ set_l2_persist((void*)pt, (size_t)rows * num_groups * 8 * sizeof(int32_t));
564
+ }
565
+ // Launch n_repeats sequential v28 kernels in the SAME stream
566
+ // This measures the pipeline benefit: back-to-back launches share L2
567
+ for (int i = 0; i < n_repeats; i++) {
568
+ k_v28<<<(rows + V28_RPB - 1) / V28_RPB, V28_BS>>>(
569
+ (const uint32_t*)pt, ws,
570
+ (const uint32_t*)xt_e, (const uint32_t*)xt_o, xs,
571
+ y, cols, rows, num_groups);
572
+ }
573
+ if (use_l2_persist) {
574
+ clear_l2_persist();
575
+ }
576
+ }
577
+
578
+ // Query L2 cache size (for deciding whether to use L2 persist)
579
+ int get_l2_cache_bytes() {
580
+ cudaDeviceProp prop;
581
+ cudaGetDeviceProperties(&prop, 0);
582
+ return prop.l2CacheSize;
583
+ }
584
+
585
+ // Query GPU name
586
+ void get_gpu_name(char* buf, int buflen) {
587
+ cudaDeviceProp prop;
588
+ cudaGetDeviceProperties(&prop, 0);
589
+ strncpy(buf, prop.name, buflen - 1);
590
+ buf[buflen - 1] = '\0';
591
+ }
592
+
593
+ // Synchronize (for timing from Python)
594
+ void cuda_sync() {
595
+ cudaDeviceSynchronize();
596
+ }
597
+
598
+ } // extern "C"