initial public release: code, README, KNOWN_ISSUES
- KNOWN_ISSUES.md +64 -0
- README.md +116 -0
- build.sh +52 -0
- trit_gemv.cu +292 -0
- trit_gemv_lib.py +280 -0
- trit_gemv_standalone.cu +598 -0
KNOWN_ISSUES.md
ADDED
@@ -0,0 +1,64 @@
# Known issues — tritllm-kernel

Surfaced during a pre-release code review. None affect the published paper benchmark numbers (those were obtained on shapes that respect the contract), but anyone using these kernels with new shapes, custom launch parameters, or as a drop-in inference primitive should be aware.

## BLOCKER — must respect or fix before relying on the kernel

### 1. Implicit one-warp-per-block launch contract
**Where:** [`trit_gemv.cu:190-237` (`trit_gemv_uniform`)](trit_gemv.cu#L190), [`trit_gemv.cu:245-290` (`trit_gemv_variable`)](trit_gemv.cu#L245)

The kernels use `lane = threadIdx.x` directly as the lane index and reduce with a full-warp mask `__shfl_down_sync(0xFFFFFFFF, ...)`. This is correct only when `blockDim.x == 32`.

If launched with `blockDim.x > 32`:
- Threads with `threadIdx.x >= 32` compute `idx = lane*2+i` past the 64-element group bound and read out of bounds.
- The partial sums of warps beyond the first never reach the `lane == 0` accumulation, so `y[row]` is wrong even when the out-of-bounds reads do not fault.

**Fix in caller:** always launch with `blockDim.x == 32`. The host-side wrappers in `trit_gemv_standalone.cu` do this correctly. Direct callers from custom code must respect it.

**Future fix in kernel:** add `assert(blockDim.x == WARP_SIZE)` at kernel entry, or rewrite to handle multi-warp blocks correctly.

### 2. `in_features` not a multiple of `GROUP_SIZE` is silently dropped
**Where:** [`trit_gemv.cu:194`](trit_gemv.cu#L194), [`trit_gemv.cu:259`](trit_gemv.cu#L259)

```cpp
int num_groups = in_features / GROUP_SIZE;
```

Integer division truncates. If `in_features % 64 != 0`, the trailing partial group is silently skipped and that fragment of the dot product is missing from the output.

**Fix in caller:** pad the input weight matrix (and activations) with zero rows to the next multiple of 64 before quantizing. The codec output already does this for Qwen, Llama, and Mistral architectures, all of which have `hidden_dim` divisible by 64.

**Future fix in kernel:** add `assert(in_features % GROUP_SIZE == 0)` at kernel entry, or write a tail-handling path.

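The padding step is a few lines of plain Python. A sketch (hypothetical helper, not part of the shipped codec; zero weights contribute nothing to the dot product, so the padding is numerically a no-op):

```python
GROUP_SIZE = 64  # fixed group size used by the kernels

def pad_in_features(rows):
    """Zero-pad each per-output-row weight list up to the next multiple of GROUP_SIZE."""
    pad = (-len(rows[0])) % GROUP_SIZE
    return [list(r) + [0.0] * pad for r in rows]
```

Pad the activation vector with the same number of trailing zeros before quantizing it.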
## SHOULD-FIX

### 3. C API performs no input validation
**Where:** `trit_gemv_standalone.cu`, all `extern "C"` functions

`trit_gemv_d2_fast`, `trit_gemv_d2_dp4a`, `trit_gemv_d3_native`, etc. accept null pointers, mismatched `rows`/`cols`/`num_groups`, and incorrectly packed buffers without complaint. Bad inputs become device faults or OOB reads.

For a public ctypes-facing library this is sharp. We will add a validation pass in a future revision; for now, callers must guarantee their arguments.

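Until that validation pass lands, a caller-side guard is cheap to write. A minimal sketch (hypothetical helper name; the shape relations follow the launch contract documented here):

```python
def check_d2_args(cols, rows, num_groups, group_size=64):
    """Raise before handing bad shapes to the unchecked C API."""
    if rows <= 0 or cols <= 0:
        raise ValueError(f"rows/cols must be positive, got {rows}x{cols}")
    if cols % group_size != 0:
        raise ValueError(f"cols={cols} must be a multiple of {group_size}")
    if num_groups != cols // group_size:
        raise ValueError(
            f"num_groups={num_groups} != cols // {group_size} = {cols // group_size}")
```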
### 4. `get_gpu_name(char* buf, int buflen)` has no null/length guard
**Where:** [`trit_gemv_standalone.cu:700`](trit_gemv_standalone.cu#L700)

Calling with `buf == nullptr` or `buflen <= 0` is immediate UB on the host side. Trivial fix; pending.

### 5. CUDA error returns are not surfaced
**Where:** several places in `trit_gemv_standalone.cu` where `set_l2_persist`, kernel launches, and helper calls drop `cudaError_t` returns

If a kernel launch fails (e.g., bad shapes that pass the (missing) input validation), the failure is silent until the next `cudaDeviceSynchronize()` or `cudaGetLastError()`. The public functions return `void` and have no error-reporting path.

Workaround: call `cuda_sync()` after each operation and check `cudaGetLastError()` from your wrapper.

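That workaround can be packaged once in the caller. A sketch (`cudaDeviceSynchronize` and `cudaGetLastError` are standard CUDA runtime calls reachable through `libcudart`; the wrapper names here are hypothetical):

```python
import ctypes

def raise_if_cuda_error(code, what="kernel launch"):
    """Turn a nonzero cudaError_t into a Python exception."""
    if code != 0:
        raise RuntimeError(f"CUDA error {code} after {what}")

def checked_call(fn, *args, cudart_path="libcudart.so"):
    """Call a void-returning kernel wrapper, then sync and surface any error."""
    cudart = ctypes.CDLL(cudart_path)
    fn(*args)
    raise_if_cuda_error(cudart.cudaDeviceSynchronize(), "sync")
    raise_if_cuda_error(cudart.cudaGetLastError(), "launch")
```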
### 6. Reduction wastes 31 lanes per group
**Where:** [`trit_gemv.cu:223-232`](trit_gemv.cu#L223), [`trit_gemv.cu:279-286`](trit_gemv.cu#L279)

After the warp reduction, only lane 0 multiplies by the group scale and accumulates into `row_acc`. The other 31 lanes idle for the scale/add path. This is correct, just leaves performance on the table relative to the deferred-reduction design used in `k_d3_hardened` (`trit_gemv_standalone.cu:493`).

The headline 7.8× number is from the deferred-reduction path, so this only matters if you use the educational `trit_gemv_uniform` / `trit_gemv_variable` kernels directly.

## NIT

### 7. Multiple prototype kernels in production file
`trit_gemv_standalone.cu` contains v9, v27, v28, v29, `k_d3_hardened`, plus the non-deferred kernels — a development history rather than a clean public surface. The `k_v29_pipeline` / `trit_pipeline` path was broken (it passed nullptr for required arrays) and was removed in a commit prior to this release. The remaining prototypes (`k_v27`, `k_v28`, `k_v29`) are still wired through public C functions; they work, but the API surface is wider than needed. A future revision will trim to one canonical entry per depth.
README.md
ADDED
@@ -0,0 +1,116 @@
---
license: apache-2.0
tags:
- cuda
- quantization
- ternary
- llm-inference
- kernel
---

# tritllm-kernel

Multiply-free ternary GEMV CUDA kernel for the codec from
**"Balanced Ternary Post-Training Quantization for Large Language Models"** (Stentzel, 2026).

The headline number from the paper: **7.8× speedup** over cuBLAS FP16 GEMV on RTX 4090 in the memory-bound regime, projected to full-model token-generation throughput from per-layer benchmarks.

> **These are kernel-only projections, not end-to-end serving throughput.** They exclude attention, KV cache, sampling, and tokenizer overhead. See Section 7 of the paper for methodology.

## What it is

A standalone CUDA shared library (`libtrit_gemv.so` / `.dll`) callable via `ctypes` from any language, with no PyTorch dependency. The same algorithm is also wrapped via PyTorch's pybind11 in `trit_gemv_wrapper.cu` for benchmarking.

The core trick: each ternary weight (-1, 0, +1) reduces a multiply-accumulate to a conditional add/subtract/skip. The kernel uses `dp4a` intrinsics on int4-packed weights and pre-interleaved int8 activations to do four ternary-times-int8 dot products per instruction.

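The multiply-free reduction is easiest to see in scalar form. A reference sketch of one group's contribution (plain Python, for exposition only — the kernel does this with packed trits and `dp4a`):

```python
def ternary_group_dot(trits, x, scale):
    """Dot product of one weight group: each trit in {-1, 0, +1}
    turns a multiply-accumulate into add / subtract / skip."""
    acc = 0.0
    for t, v in zip(trits, x):
        if t > 0:
            acc += v      # +1: add the activation
        elif t < 0:
            acc -= v      # -1: subtract it
        # 0: skip entirely
    return scale * acc    # one FP multiply per group, not per weight
```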
## Build

```bash
cd kernel
./build.sh
```

The build script targets SM 70/75/80/86/89/90 (plus Blackwell, SM 120, when `nvcc` is new enough) in one fat binary, so the `.so` runs on V100, T4, A100, RTX 30/40/50, and H100 without recompilation.

Required: `nvcc` (CUDA 11.8 or newer) and a C++ compiler.

## Performance (Qwen2.5-7B, d=2, 3.47 bpw, 3.3 GB model)

| GPU | L2 cache | Tokens/sec | Speedup vs FP16 cuBLAS | Effective BW |
|---|---|---|---|---|
| RTX 4090 | 72 MB | 588 | 7.8× | 1940 GB/s |
| RTX 3090 | 6 MB | 192 | 3.4× | 633 GB/s |
| RTX 4080 Laptop | 64 MB | 133 | 5.8× | 439 GB/s |
| A100 80GB | 40 MB | 201 | 4.2× | 663 GB/s |

These are per-layer GEMV benchmarks projected to full-model token-generation throughput. L2 cache size correlates strongly with speedup: each d=2 layer fits in L2 on the RTX 4090, giving an effective bandwidth roughly 2× its DRAM bandwidth.

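The effective-bandwidth column follows directly from the model size: in the memory-bound regime each generated token streams the full 3.3 GB of weights once, so effective bandwidth ≈ model size × tokens/sec. A quick arithmetic check against the table above:

```python
MODEL_GB = 3.3  # Qwen2.5-7B at d=2, 3.47 bpw

reported = {  # GPU: (tokens/sec, effective BW in GB/s), from the table
    "RTX 4090": (588, 1940),
    "RTX 3090": (192, 633),
    "RTX 4080 Laptop": (133, 439),
    "A100 80GB": (201, 663),
}

for gpu, (tok_s, bw) in reported.items():
    est = MODEL_GB * tok_s  # GB/s streamed for weights alone
    assert abs(est - bw) <= 1.0, (gpu, est, bw)  # matches to rounding
```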
See `kernel/bench_*.py` for the benchmark drivers.

## Launch contract

The kernels in `trit_gemv.cu` and `trit_gemv_standalone.cu` assume:

| Constraint | Why | What happens if violated |
|---|---|---|
| `blockDim.x == 32` (one warp per block) | Kernels use `__shfl_down_sync(0xFFFFFFFF, ...)` and lane-0 reduction | OOB index reads + wrong `y[row]` |
| `in_features % 64 == 0` | Group size is fixed at 64 weights | Trailing partial group is silently dropped — incorrect output for that row |
| Weight, scale, and activation buffers are device-resident and properly aligned | Kernel uses `__ldg` for cached loads | UB / device fault |

If your model has `in_features` not divisible by 64, pad the weight matrix to the next multiple of 64 with zero rows before quantizing.

## API surface

C ABI in `trit_gemv_standalone.cu`:

```c
// Best-tested d=2 path (champion for 4090)
void trit_gemv_d2_fast(
    const int32_t* pt,    // [rows * num_groups * 8] int4-packed weights
    const float*   ws,    // [rows * num_groups] scales
    const int32_t* xt_e,  // [num_groups * 8] even-nibble activations
    const int32_t* xt_o,  // [num_groups * 8] odd-nibble activations
    const float*   xs,    // [num_groups] activation scales
    float*         y,     // [rows] output
    int cols, int rows, int num_groups,
    int use_l2_persist    // 0 = off, 1 = enable L2 persistence
);

// Native-trit packed d=3 (no int4 intermediate)
void trit_gemv_d3_native(
    const int32_t* pt,    // [rows * num_groups * 13] trit-packed
    const float*   sc,
    const float*   x,
    float*         y,
    int cols, int rows, int depth
);

// L2 cache size query (for deciding whether to enable persist)
int get_l2_cache_bytes();
void get_gpu_name(char* buf, int buflen);
void cuda_sync();
```

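When calling `trit_gemv_d2_fast` through `ctypes`, the easiest mistake is a buffer-length mismatch, which the C API will not catch. A helper that spells out the element counts implied by the ABI comments above (hypothetical convenience function):

```python
def d2_buffer_lengths(rows, num_groups):
    """Element counts for each trit_gemv_d2_fast argument, per the C ABI comments."""
    return {
        "pt":   rows * num_groups * 8,  # int32, int4-packed weights
        "ws":   rows * num_groups,      # float, per-group weight scales
        "xt_e": num_groups * 8,         # int32, even-nibble activations
        "xt_o": num_groups * 8,         # int32, odd-nibble activations
        "xs":   num_groups,             # float, activation scales
        "y":    rows,                   # float, output
    }
```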
## Known issues

Documented in [KNOWN_ISSUES.md](KNOWN_ISSUES.md). Summary:

- **Launch contract is implicit, not enforced.** Kernels are correct only with `blockDim.x == 32`. There are no runtime asserts; the contract is guarded only by the host-side wrappers in this file. Direct callers must respect it.
- **`in_features` not a multiple of 64 silently fails.** No assert. Pad your matrix.
- **C API has no input validation.** Null pointers, wrong dimensions, and buffer-shape mismatches become device faults or OOB reads. This is a public-API hardening item we have not yet completed.
- **CUDA error returns are not surfaced to the caller** in some helper paths. If a kernel launch fails, `cuda_sync()` will see it but the public functions return `void`.

## Citation

```
@article{stentzel2026ternaryptq,
  title  = {Balanced Ternary Post-Training Quantization for Large Language Models},
  author = {Stentzel, Eric},
  year   = 2026,
  note   = {Entrit Systems}
}
```

## License

Apache-2.0.
build.sh
ADDED
@@ -0,0 +1,52 @@
#!/bin/bash
# Build libtrit_gemv.so — standalone CUDA kernel library
# No PyTorch, no Python, no framework dependency.
# Just nvcc + CUDA runtime.
#
# Fat binary: compiles for all major GPU architectures.
# The right kernel is selected at runtime based on the GPU.

set -e
cd "$(dirname "$0")"

# Detect nvcc
NVCC=$(which nvcc 2>/dev/null || echo "/usr/local/cuda/bin/nvcc")
if [ ! -x "$NVCC" ]; then
    echo "ERROR: nvcc not found. Install CUDA toolkit."
    exit 1
fi

echo "Using nvcc: $NVCC"
$NVCC --version | head -1

# Architecture targets (fat binary)
# Volta (V100), Turing (2080), Ampere (3090/A100),
# Ada (4080/4090), Hopper (H100), Blackwell (5070+)
ARCHS=""
ARCHS="$ARCHS -gencode=arch=compute_70,code=sm_70"    # V100
ARCHS="$ARCHS -gencode=arch=compute_75,code=sm_75"    # 2080
ARCHS="$ARCHS -gencode=arch=compute_80,code=sm_80"    # A100, 3080
ARCHS="$ARCHS -gencode=arch=compute_86,code=sm_86"    # 3090
ARCHS="$ARCHS -gencode=arch=compute_89,code=sm_89"    # 4080, 4090
ARCHS="$ARCHS -gencode=arch=compute_90,code=sm_90"    # H100

# Blackwell — only if nvcc supports it (CUDA 12.8+)
if $NVCC --help 2>&1 | grep -q "compute_120"; then
    ARCHS="$ARCHS -gencode=arch=compute_120,code=sm_120"
    echo "Including Blackwell (sm_120)"
fi

echo "Building libtrit_gemv.so..."
$NVCC -O3 --use_fast_math \
    -shared -Xcompiler -fPIC \
    $ARCHS \
    -o libtrit_gemv.so \
    trit_gemv_standalone.cu

ls -la libtrit_gemv.so
echo "Done! Library ready at $(pwd)/libtrit_gemv.so"
echo ""
echo "Usage from Python:"
echo "  from trit_gemv_lib import TritGEMV"
echo "  lib = TritGEMV()"
echo "  lib.gemv_d2(weights, scales, x_int8, x_scales, output, K, M, ng)"
trit_gemv.cu
ADDED
@@ -0,0 +1,292 @@
/*
 * TritLLM CUDA Kernel — Ternary GEMV (Matrix-Vector Multiply)
 *
 * Core operation: y = W_ternary @ x
 * Where W_ternary is packed ternary weights with per-group scales.
 *
 * Each group of 64 weights has:
 *   - A depth (1-4 trits per weight)
 *   - A FP16 scale factor
 *   - Packed trit values (2 bits per trit: 00=0, 01=+1, 10=-1, 11=unused)
 *
 * The key: NO floating-point multiply in the inner loop.
 * Ternary MAC = conditional add/subtract.
 */

#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <stdint.h>

#define GROUP_SIZE 64
#define WARP_SIZE 32

// Trit encoding: 2 bits per trit
// 00 = 0, 01 = +1, 10 = -1
#define TRIT_ZERO 0
#define TRIT_POS 1
#define TRIT_NEG 2

/*
 * Depth 1 (3 levels: {-1, 0, +1}): 1 trit per weight, 2 bits per weight
 * Pack 16 trits per uint32 (16 * 2 = 32 bits)
 * Group of 64 = 4 uint32s
 *
 * Inner loop: read trit, branch-free conditional accumulate
 */
__device__ __forceinline__ float trit_mac_d1(
    const uint32_t* __restrict__ packed,  // 4 uint32s = 64 trits
    const float* __restrict__ x,          // 64 activations
    int lane                              // warp lane (0-31)
) {
    float acc = 0.0f;

    // Each thread in warp handles 2 elements (64 / 32 = 2)
    #pragma unroll
    for (int i = 0; i < 2; i++) {
        int idx = lane * 2 + i;
        int word = idx / 16;              // which uint32 (0-3)
        int bit_offset = (idx % 16) * 2;  // bit position within word

        uint32_t trit = (packed[word] >> bit_offset) & 0x3;
        float val = x[idx];

        // Branch-free: acc += (trit == 1) * val - (trit == 2) * val
        acc += ((trit == TRIT_POS) - (trit == TRIT_NEG)) * val;
    }

    return acc;
}

/*
 * Depth 2 (9 levels: {-4..+4}): 2 trits per weight, 4 bits per weight
 * Trit value = trit1 * 3 + trit0 - 4 (maps to -4..+4)
 * Pack 8 values per uint32 (8 * 4 = 32 bits)
 * Group of 64 = 8 uint32s
 */
__device__ __forceinline__ float trit_mac_d2(
    const uint32_t* __restrict__ packed,  // 8 uint32s = 64 values
    const float* __restrict__ x,
    int lane
) {
    float acc = 0.0f;

    #pragma unroll
    for (int i = 0; i < 2; i++) {
        int idx = lane * 2 + i;
        int word = idx / 8;
        int bit_offset = (idx % 8) * 4;

        uint32_t bits = (packed[word] >> bit_offset) & 0xF;
        // Decode: trit1 = bits >> 2, trit0 = bits & 0x3
        // value = (trit1_sign * 3 + trit0_sign)
        // where trit_sign: 00->0, 01->+1, 10->-1
        int t0 = (int)(bits & 0x3);
        int t1 = (int)((bits >> 2) & 0x3);
        int sign0 = (t0 == TRIT_POS) - (t0 == TRIT_NEG);
        int sign1 = (t1 == TRIT_POS) - (t1 == TRIT_NEG);
        int level = sign1 * 3 + sign0;  // -4 to +4

        // Still no FP multiply — integer * float is one instruction
        // level is a small integer, compiler optimizes to repeated add
        acc += level * x[idx];
    }

    return acc;
}

/*
 * Depth 3 (27 levels: {-13..+13}): 3 trits per weight, 6 bits per weight
 * Pack 5 values per uint32 (5 * 6 = 30 bits, 2 wasted)
 * Group of 64 = 13 uint32s (64 values, last uint32 has 4 values)
 */
__device__ __forceinline__ float trit_mac_d3(
    const uint32_t* __restrict__ packed,  // 13 uint32s
    const float* __restrict__ x,
    int lane
) {
    float acc = 0.0f;

    #pragma unroll
    for (int i = 0; i < 2; i++) {
        int idx = lane * 2 + i;
        int word = idx / 5;
        int pos = idx % 5;
        int bit_offset = pos * 6;

        uint32_t bits = (packed[word] >> bit_offset) & 0x3F;
        int t0 = (int)(bits & 0x3);
        int t1 = (int)((bits >> 2) & 0x3);
        int t2 = (int)((bits >> 4) & 0x3);
        int s0 = (t0 == TRIT_POS) - (t0 == TRIT_NEG);
        int s1 = (t1 == TRIT_POS) - (t1 == TRIT_NEG);
        int s2 = (t2 == TRIT_POS) - (t2 == TRIT_NEG);
        int level = s2 * 9 + s1 * 3 + s0;  // -13 to +13

        acc += level * x[idx];
    }

    return acc;
}

/*
 * Depth 4 (81 levels: {-40..+40}): 4 trits per weight, 8 bits per weight
 * Pack 4 values per uint32 (4 * 8 = 32 bits, perfect)
 * Group of 64 = 16 uint32s
 */
__device__ __forceinline__ float trit_mac_d4(
    const uint32_t* __restrict__ packed,  // 16 uint32s
    const float* __restrict__ x,
    int lane
) {
    float acc = 0.0f;

    #pragma unroll
    for (int i = 0; i < 2; i++) {
        int idx = lane * 2 + i;
        int word = idx / 4;
        int bit_offset = (idx % 4) * 8;

        uint32_t bits = (packed[word] >> bit_offset) & 0xFF;
        int t0 = (int)(bits & 0x3);
        int t1 = (int)((bits >> 2) & 0x3);
        int t2 = (int)((bits >> 4) & 0x3);
        int t3 = (int)((bits >> 6) & 0x3);
        int s0 = (t0 == TRIT_POS) - (t0 == TRIT_NEG);
        int s1 = (t1 == TRIT_POS) - (t1 == TRIT_NEG);
        int s2 = (t2 == TRIT_POS) - (t2 == TRIT_NEG);
        int s3 = (t3 == TRIT_POS) - (t3 == TRIT_NEG);
        int level = s3 * 27 + s2 * 9 + s1 * 3 + s0;

        acc += level * x[idx];
    }

    return acc;
}

/*
 * Main GEMV kernel: y[out_features] = W[out_features, in_features] @ x[in_features]
 *
 * W is stored as packed ternary groups:
 *   - packed_trits: variable-length packed trit data per group
 *   - scales: FP16 scale per group
 *   - depths: uint8 depth per group (1-4)
 *   - group_offsets: byte offset into packed_trits for each group
 *
 * One warp per output row, iterating over groups along the input dimension.
 * Warp reduction gives the final dot product.
 */

// Simplified version: uniform depth across all groups in a tensor
// (variable-depth version below)
__global__ void trit_gemv_uniform(
    const uint32_t* __restrict__ packed_trits,  // packed trit data
    const float* __restrict__ scales,           // [out_features * num_groups] FP16 scales stored as float
    const float* __restrict__ x,                // [in_features]
    float* __restrict__ y,                      // [out_features]
    int in_features,
    int out_features,
    int depth                                   // uniform depth 1-4
) {
    int row = blockIdx.x;  // one block per output row
    if (row >= out_features) return;

    int lane = threadIdx.x;  // lane within warp (0-31)
    int num_groups = in_features / GROUP_SIZE;

    // Words per group depends on depth
    int words_per_group;
    switch (depth) {
        case 1: words_per_group = 4;  break;  // 64 * 2 / 32
        case 2: words_per_group = 8;  break;  // 64 * 4 / 32
        case 3: words_per_group = 13; break;  // ceil(64 * 6 / 32)
        case 4: words_per_group = 16; break;  // 64 * 8 / 32
        default: words_per_group = 4; break;
    }

    float row_acc = 0.0f;

    for (int g = 0; g < num_groups; g++) {
        int group_offset = (row * num_groups + g) * words_per_group;
        const uint32_t* group_data = &packed_trits[group_offset];
        const float* group_x = &x[g * GROUP_SIZE];
        float scale = scales[row * num_groups + g];

        float group_acc;
        switch (depth) {
            case 1: group_acc = trit_mac_d1(group_data, group_x, lane); break;
            case 2: group_acc = trit_mac_d2(group_data, group_x, lane); break;
            case 3: group_acc = trit_mac_d3(group_data, group_x, lane); break;
            case 4: group_acc = trit_mac_d4(group_data, group_x, lane); break;
            default: group_acc = 0.0f; break;
        }

        // Warp reduction
        #pragma unroll
        for (int offset = 16; offset > 0; offset >>= 1) {
            group_acc += __shfl_down_sync(0xFFFFFFFF, group_acc, offset);
        }

        // Lane 0 accumulates the scaled result
        if (lane == 0) {
            row_acc += group_acc * scale;
        }
    }

    // Write output
    if (lane == 0) {
        y[row] = row_acc;
    }
}

/*
 * Variable-depth version: each group can have a different depth.
 * Uses a depth map and offset table to handle mixed-depth tensors.
 */
__global__ void trit_gemv_variable(
    const uint32_t* __restrict__ packed_trits,
    const float* __restrict__ scales,
    const uint8_t* __restrict__ depth_map,      // [num_groups_per_row] depth per group
    const int* __restrict__ group_offsets,      // [num_groups_per_row + 1] word offsets
    const float* __restrict__ x,
    float* __restrict__ y,
    int in_features,
    int out_features
) {
    int row = blockIdx.x;
    if (row >= out_features) return;

    int lane = threadIdx.x;
    int num_groups = in_features / GROUP_SIZE;

    float row_acc = 0.0f;

    for (int g = 0; g < num_groups; g++) {
        int depth = depth_map[g];
        int word_offset = group_offsets[g] + row * group_offsets[num_groups];  // row stride
        const uint32_t* group_data = &packed_trits[word_offset];
        const float* group_x = &x[g * GROUP_SIZE];
        float scale = scales[row * num_groups + g];

        float group_acc;
        switch (depth) {
            case 1: group_acc = trit_mac_d1(group_data, group_x, lane); break;
            case 2: group_acc = trit_mac_d2(group_data, group_x, lane); break;
            case 3: group_acc = trit_mac_d3(group_data, group_x, lane); break;
            case 4: group_acc = trit_mac_d4(group_data, group_x, lane); break;
            default: group_acc = 0.0f; break;
        }

        #pragma unroll
        for (int offset = 16; offset > 0; offset >>= 1) {
            group_acc += __shfl_down_sync(0xFFFFFFFF, group_acc, offset);
        }

        if (lane == 0) {
            row_acc += group_acc * scale;
        }
    }

    if (lane == 0) {
        y[row] = row_acc;
    }
}
trit_gemv_lib.py
ADDED
@@ -0,0 +1,280 @@
"""Framework-agnostic trit GEMV library.

Loads the pre-compiled libtrit_gemv.so via ctypes.
Works with PyTorch, JAX, CuPy, or raw CUDA pointers.

Compile the library once:
    cd kernel/
    ./build.sh

Then use from any framework:
    from trit_gemv_lib import TritGEMV
    lib = TritGEMV()

    # PyTorch
    lib.gemv_d2(pt_tensor, ws_tensor, xt_tensor, xs_tensor, y_tensor, cols, rows, ng)

    # Raw pointers (CuPy, JAX, etc.) — the same call accepts integer device pointers
    lib.gemv_d2(pt_ptr, ws_ptr, xt_ptr, xs_ptr, y_ptr, cols, rows, ng)
"""
import ctypes
import os
import subprocess
import sys

# Find the library
_LIB_NAMES = ['libtrit_gemv.so', 'libtrit_gemv.dll', 'trit_gemv.so']
_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))


def _find_lib():
    for name in _LIB_NAMES:
        path = os.path.join(_SCRIPT_DIR, name)
        if os.path.exists(path):
            return path
    return None


def _build_lib():
    """Auto-compile if not found."""
    build_script = os.path.join(_SCRIPT_DIR, 'build.sh')
    if os.path.exists(build_script):
        print("Building libtrit_gemv.so...", flush=True)
        subprocess.run(['bash', build_script], cwd=_SCRIPT_DIR, check=True)
    else:
        # Inline build
        cu_file = os.path.join(_SCRIPT_DIR, 'trit_gemv_standalone.cu')
        out_file = os.path.join(_SCRIPT_DIR, 'libtrit_gemv.so')
        if not os.path.exists(cu_file):
            raise FileNotFoundError(f"Cannot find {cu_file}")

        # Detect GPU architecture
        try:
            import torch
            cc = torch.cuda.get_device_capability(0)
            arch = f"compute_{cc[0]}{cc[1]}"
            sm = f"sm_{cc[0]}{cc[1]}"
            gencode = f"-gencode=arch={arch},code={sm}"
        except Exception:
            # Default to common architectures
            gencode = " ".join([
                f"-gencode=arch=compute_{a},code=sm_{a}"
                for a in ["70", "75", "80", "86", "89", "90"]
            ])

        cmd = f"nvcc -O3 --use_fast_math -shared -Xcompiler -fPIC {gencode} -o {out_file} {cu_file}"
        print(f"Compiling: {cmd}", flush=True)
        subprocess.run(cmd, shell=True, check=True)

    return _find_lib()


class TritGEMV:
    """Framework-agnostic trit GEMV kernel."""

    def __init__(self, lib_path=None):
        if lib_path is None:
            lib_path = _find_lib()
        if lib_path is None:
            lib_path = _build_lib()
        if lib_path is None:
            raise RuntimeError("Cannot find or build libtrit_gemv.so")

        self._lib = ctypes.CDLL(lib_path)

        # Set up function signatures
        # d2 dp4a (champion)
        self._lib.trit_gemv_d2_dp4a.argtypes = [
            ctypes.c_void_p,  # pt (int32*)
            ctypes.c_void_p,  # ws (float*)
            ctypes.c_void_p,  # xt (int32*)
            ctypes.c_void_p,  # xs (float*)
            ctypes.c_void_p,  # y (float*)
            ctypes.c_int,     # cols
            ctypes.c_int,     # rows
            ctypes.c_int,     # num_groups
            ctypes.c_int,     # use_l2_persist
        ]
        self._lib.trit_gemv_d2_dp4a.restype = None

        # d3 native trit
        self._lib.trit_gemv_d3_native.argtypes = [
            ctypes.c_void_p,  # pt
            ctypes.c_void_p,  # sc
            ctypes.c_void_p,  # x
            ctypes.c_void_p,  # y
            ctypes.c_int,     # cols
            ctypes.c_int,     # rows
            ctypes.c_int,     # depth
        ]
        self._lib.trit_gemv_d3_native.restype = None

        # d3 int8 dp4a (no decode, DRAM-bound path)
        self._lib.trit_gemv_d3_int8_dp4a.argtypes = [
            ctypes.c_void_p,  # wt (int32*)
            ctypes.c_void_p,  # ws (float*)
            ctypes.c_void_p,  # xt (int32*)
            ctypes.c_void_p,  # xs (float*)
            ctypes.c_void_p,  # y (float*)
            ctypes.c_int,     # cols
            ctypes.c_int,     # rows
            ctypes.c_int,     # num_groups
            ctypes.c_int,     # use_l2_persist
        ]
        self._lib.trit_gemv_d3_int8_dp4a.restype = None

        # Utility
        self._lib.get_l2_cache_bytes.restype = ctypes.c_int
        self._lib.cuda_sync.restype = None

        buf = ctypes.create_string_buffer(256)
        self._lib.get_gpu_name(buf, 256)
        self.gpu_name = buf.value.decode()
        self.l2_bytes = self._lib.get_l2_cache_bytes()

    def sync(self):
        self._lib.cuda_sync()

    def _get_ptr(self, tensor):
        """Extract GPU pointer from any framework's tensor."""
        if hasattr(tensor, 'data_ptr'):
            # PyTorch
            return tensor.data_ptr()
        elif hasattr(tensor, '__cuda_array_interface__'):
            # CuPy, JAX, Numba
            return tensor.__cuda_array_interface__['data'][0]
        elif isinstance(tensor, int):
            # Raw pointer
            return tensor
        else:
            raise TypeError(f"Cannot extract GPU pointer from {type(tensor)}")

    def gemv_d2(self, pt, ws, xt, xs, y, cols, rows, num_groups, l2_persist=True):
        """D2 GEMV with int4 packing + dp4a.

        Args:
            pt: int32 tensor [rows * num_groups * 8] — int4 packed weights
            ws: float32 tensor [rows * num_groups] — weight scales
            xt: int32 tensor [num_groups * 16] — int8 packed activations
            xs: float32 tensor [num_groups] — activation scales
            y: float32 tensor [rows] — output (written in-place)
            cols: input dimension (K)
            rows: output dimension (M)
            num_groups: K // 64
            l2_persist: enable L2 cache persistence (default True)
        """
        self._lib.trit_gemv_d2_dp4a(
            self._get_ptr(pt), self._get_ptr(ws),
            self._get_ptr(xt), self._get_ptr(xs),
            self._get_ptr(y), cols, rows, num_groups,
            1 if l2_persist else 0,
        )

    def gemv_adaptive(self, pt_int4, ws, xt, xs, y, cols, rows, num_groups,
                      pt_int8=None):
        """Hardware-aware GEMV: auto-selects the best kernel based on L2 cache.

        If the int4 weight data fits in L2 → uses d2 int4 + dp4a (5x FP16)
        If not → uses pre-expanded int8 + dp4a (2x FP16, no decode overhead)

        Args:
            pt_int4: int32 tensor — int4 packed weights (always stored, compact)
            ws: weight scales
            xt, xs: quantized activations
            y: output
            pt_int8: optional pre-expanded int8 weights for the DRAM path.
                Required for layers that do not fit in L2 (see below).
        """
        weight_bytes = rows * num_groups * 8 * 4  # int4: 8 words per group
        l2_margin = self.l2_bytes * 0.75  # leave 25% for x, scales, other data

        if weight_bytes < l2_margin:
            # Fits in L2 → use compact int4, decode inline at L2 speed
            self._lib.trit_gemv_d2_dp4a(
                self._get_ptr(pt_int4), self._get_ptr(ws),
                self._get_ptr(xt), self._get_ptr(xs),
                self._get_ptr(y), cols, rows, num_groups, 1)
        else:
            # Doesn't fit in L2 → use int8 for zero-decode DRAM speed
            if pt_int8 is None:
                raise ValueError(
                    f"Layer ({weight_bytes/1e6:.0f} MB) exceeds L2 ({self.l2_bytes/1e6:.0f} MB). "
                    f"Provide pre-expanded pt_int8 for DRAM path. "
                    f"Use TritGEMV.expand_int4_to_int8(pt_int4) at model load time."
                )
            self._lib.trit_gemv_d3_int8_dp4a(
                self._get_ptr(pt_int8), self._get_ptr(ws),
                self._get_ptr(xt), self._get_ptr(xs),
                self._get_ptr(y), cols, rows, num_groups, 0)

    @staticmethod
    def expand_int4_to_int8(pt_int4, device='cuda'):
        """Pre-expand int4 packed weights to int8 for DRAM-bound layers.

        Called once at model load. Uses 2x more VRAM but eliminates decode overhead.
        int4: 8 words per group → int8: 16 words per group

        Args:
            pt_int4: int32 tensor [n_groups * 8] — int4 packed
        Returns:
            int32 tensor [n_groups * 16] — int8 packed (dp4a compatible)
        """
        import torch
        n_words = pt_int4.shape[0]
        n_groups = n_words // 8

        # Each int4 word has 8 nibbles → 8 int8 values → 2 int8x4 words
        pt_int8 = torch.zeros(n_groups * 16, dtype=torch.int32, device=device)

        # Reference expansion: a scalar host loop (slow, one-time cost at load)
        for g in range(n_groups):
            for w in range(8):
                word = pt_int4[g * 8 + w].item()
                for nib in range(8):
                    val = (word >> (nib * 4)) & 0xF
                    if val & 0x8:
                        val = val | 0xFFFFFFF0  # sign extend
                    val = val & 0xFF
                    out_col = w * 8 + nib
                    out_word = out_col // 4
                    out_byte = out_col % 4
                    pt_int8[g * 16 + out_word] |= (val << (out_byte * 8))

        return pt_int8

    def gemv_d3(self, pt, sc, x, y, cols, rows, depth=3):
        """D3 GEMV with native trit packing.

        Args:
            pt: int32 tensor [rows * ng * 13] — trit packed weights
            sc: float32 tensor [rows * ng] — scales
            x: float32 tensor [cols] — activations
            y: float32 tensor [rows] — output
        """
        self._lib.trit_gemv_d3_native(
            self._get_ptr(pt), self._get_ptr(sc),
            self._get_ptr(x), self._get_ptr(y),
            cols, rows, depth,
        )

    def gemv_d3_int8(self, wt, ws, xt, xs, y, cols, rows, num_groups, l2_persist=True):
        """D3 GEMV with int8 level packing + dp4a (same quality as d3, dp4a speed).

        Args:
            wt: int32 tensor [rows * num_groups * 16] — int8 packed levels
            ws: float32 tensor [rows * num_groups] — weight scales
            xt: int32 tensor [num_groups * 16] — int8 packed activations
            xs: float32 tensor [num_groups] — per-group x scales
            y: float32 tensor [rows] — output
        """
        if not hasattr(self._lib, 'trit_gemv_d3_int8_dp4a'):
            raise RuntimeError("d3 int8 not in this build — rebuild libtrit_gemv.so")
        self._lib.trit_gemv_d3_int8_dp4a(
            self._get_ptr(wt), self._get_ptr(ws),
            self._get_ptr(xt), self._get_ptr(xs),
            self._get_ptr(y), cols, rows, num_groups,
            1 if l2_persist else 0,
        )

    def __repr__(self):
        return f"TritGEMV(gpu='{self.gpu_name}', l2={self.l2_bytes/1e6:.0f}MB)"
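The nibble layout that `expand_int4_to_int8` assumes can be checked without torch or a GPU. The sketch below is ours (the helper names `pack_int4` and `expand_word_to_int8` are not part of the library) and mirrors the same pack/expand arithmetic on plain Python ints:

```python
# Pure-Python reference for the int4 -> int8 expansion performed by
# TritGEMV.expand_int4_to_int8 (no torch needed; helper names are ours).

def pack_int4(levels):
    """Pack 8 signed int4 levels (-8..7) into one uint32, nibble i at bits 4i."""
    word = 0
    for i, v in enumerate(levels):
        word |= (v & 0xF) << (i * 4)
    return word

def expand_word_to_int8(word):
    """Expand one int4-packed uint32 into two uint32s holding 8 int8 bytes."""
    out = [0, 0]
    for nib in range(8):
        val = (word >> (nib * 4)) & 0xF
        if val & 0x8:                  # sign-extend the nibble
            val |= 0xF0
        out[nib // 4] |= (val & 0xFF) << ((nib % 4) * 8)
    return out

levels = [-2, -1, 0, 1, 2, -2, 1, 0]   # d2 levels fit comfortably in int4
lo, hi = expand_word_to_int8(pack_int4(levels))

# Decode the int8 bytes back to verify the round trip
decoded = []
for w in (lo, hi):
    for b in range(4):
        v = (w >> (b * 8)) & 0xFF
        decoded.append(v - 256 if v >= 128 else v)
assert decoded == levels
```

This also documents the byte placement contract: nibbles 0-3 of each packed word land in the first output word, nibbles 4-7 in the second, matching what the dp4a kernels expect.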
trit_gemv_standalone.cu
ADDED
@@ -0,0 +1,598 @@
/*
 * Standalone trit GEMV kernel — no PyTorch dependency.
 * Compiles with nvcc to a shared library (.so/.dll).
 * Called from Python via ctypes, or from C/C++ directly.
 *
 * Compile:
 *   nvcc -O3 --use_fast_math -shared -Xcompiler -fPIC \
 *     -gencode=arch=compute_70,code=sm_70 \
 *     -gencode=arch=compute_75,code=sm_75 \
 *     -gencode=arch=compute_80,code=sm_80 \
 *     -gencode=arch=compute_86,code=sm_86 \
 *     -gencode=arch=compute_89,code=sm_89 \
 *     -gencode=arch=compute_90,code=sm_90 \
 *     -gencode=arch=compute_100,code=sm_100 \
 *     -gencode=arch=compute_120,code=sm_120 \
 *     -o libtrit_gemv.so trit_gemv_standalone.cu
 *
 * Supports: Volta(V100), Turing(2080), Ampere(3090/A100), Ada(4080/4090),
 *           Hopper(H100), Blackwell(5070/5090) — all in one binary.
 *
 * API: C functions with extern "C" — callable from any language.
 */

#include <cuda_runtime.h>
#include <stdint.h>
#include <string.h>   // memset (used by clear_l2_persist)

#define GROUP_SIZE 64
#define WARP_SIZE 32
#define TRIT_POS 1
#define TRIT_NEG 2

// Forward declarations
static void set_l2_persist(void* ptr, size_t bytes);
static void clear_l2_persist();

// ============================================================
// V27: D2 int4-packed + dp4a + L2 persist (champion kernel)
// ============================================================

// ============================================================
// V28: Branchless interleaved nibble decode (7 instructions for 8 weights)
//
// The trick: extract even nibbles (0,2,4,6) and odd nibbles (1,3,5,7)
// as separate byte vectors using mask + shift. Sign-extend all 4 bytes
// simultaneously with XOR + per-byte SUB (zero branches).
//
// x activations are pre-interleaved to match: x_evens has values at
// positions 0,2,4,6 and x_odds has 1,3,5,7. Pre-interleave is done
// once at activation quantization time (negligible cost).
//
// Instructions per 8 weights:
//   v27: 32 (loop with branches)
//   v28: 14 (7 expand + 2 dp4a + 3 load + 2 scale)
// Balance BW shift: 3.4 → 5.6 TB/s on A100 → crosses into memory-bound!
// ============================================================

#define V28_RPB 16
#define V28_WPG 8
#define V28_BS (V28_RPB * WARP_SIZE)

__global__ void k_v28(
    const uint32_t* __restrict__ pt,   // int4 packed: 8 weights per uint32
    const float* __restrict__ ws,      // weight scales [rows * ng]
    const uint32_t* __restrict__ xt_e, // x int8 EVEN positions [ng * 8]
    const uint32_t* __restrict__ xt_o, // x int8 ODD positions [ng * 8]
    const float* __restrict__ xs,      // x scales [ng]
    float* __restrict__ y,
    int cols, int rows, int num_groups
) {
    int wid = threadIdx.x / WARP_SIZE;
    int lane = threadIdx.x % WARP_SIZE;
    int row = blockIdx.x * V28_RPB + wid;
    if (row >= rows) return;

    const uint32_t* row_w = &pt[row * num_groups * V28_WPG];
    const float* row_ws = &ws[row * num_groups];

    float acc = 0.0f;
    int total_words = num_groups * V28_WPG;

    for (int base = 0; base < total_words; base += WARP_SIZE) {
        int w = base + lane;
        if (w < total_words) {
            // COALESCED load
            uint32_t word = __ldg(&row_w[w]);
            int g = w >> 3;            // group index (shift, 1 cycle)
            int word_in_group = w & 7; // word within group (mask, 1 cycle)

            // === BRANCHLESS INT4→INT8 EXPANSION ===
            // Extract even nibbles (weights 0,2,4,6) into bytes.
            // The subtract must stay inside each byte lane (__vsub4):
            // a plain 32-bit SUB would borrow across lanes whenever a
            // nibble is negative, corrupting the neighboring bytes.
            uint32_t evens = word & 0x0F0F0F0F;                 // AND
            evens = __vsub4(evens ^ 0x08080808, 0x08080808);    // XOR + per-byte SUB

            // Extract odd nibbles (weights 1,3,5,7) into bytes
            uint32_t odds = (word >> 4) & 0x0F0F0F0F;           // SHR + AND
            odds = __vsub4(odds ^ 0x08080808, 0x08080808);      // XOR + per-byte SUB
            // 8 sign-extended int8 values, zero branches

            // dp4a against pre-interleaved x
            // x_evens[g*8 + word_in_group] has activations at even positions
            // x_odds[g*8 + word_in_group] has activations at odd positions
            int x_idx = g * 8 + word_in_group;
            uint32_t xe = __ldg(&xt_e[x_idx]);
            uint32_t xo = __ldg(&xt_o[x_idx]);

            int dp_e = __dp4a((int)evens, (int)xe, 0);
            int dp_o = __dp4a((int)odds, (int)xo, 0);

            float combined_scale = __ldg(&row_ws[g]) * __ldg(&xs[g]);
            acc += (float)(dp_e + dp_o) * combined_scale;
        }
    }

    #pragma unroll
    for (int o = 16; o > 0; o >>= 1)
        acc += __shfl_down_sync(0xFFFFFFFF, acc, o);

    if (lane == 0) y[row] = acc;
}
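The expand step relies on a per-lane identity: for a zero-extended 4-bit value v, (v ^ 8) - 8 equals its sign-extended int8 value, provided the subtraction never leaves the byte lane (on the GPU that means a SIMD byte subtract; a plain 32-bit subtract would borrow across lanes). A host-side check of the identity, written by us:

```python
# Check the (v ^ 8) - 8 nibble sign-extension identity for every 4-bit value.
for v in range(16):
    trick = ((v ^ 8) - 8) & 0xFF            # one isolated byte lane
    ref = (v - 16 if v >= 8 else v) & 0xFF  # arithmetic sign extension 4 -> 8
    assert trick == ref
```

For example v = 9 (a negative nibble) gives (9 ^ 8) - 8 = -7, i.e. byte 0xF9, exactly the sign-extended value.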

// ============================================================
// V29: BIAS TRICK — zero sign extension, unsigned weights + correction
//
// Store weights as (level + 4) → range 0-8 (unsigned).
// Zero-extend nibbles: AND only, no XOR, no SUB.
// dp4a gives a biased result. Subtract a precomputed correction.
//
// Decode: 3 instructions (AND, SHR, AND) vs v28's 7
// Correction: 1 SUB + 1 LD per word (precomputed x_bias)
// Net: 10 instructions per word vs v28's 14
// ============================================================

#define V29_RPB 16
#define V29_WPG 8
#define V29_BS (V29_RPB * WARP_SIZE)

__global__ void k_v29(
    const uint32_t* __restrict__ pt,   // UNSIGNED int4: (level+4) packed, 0-8 per nibble
    const float* __restrict__ ws,      // weight scales [rows * ng]
    const uint32_t* __restrict__ xt_e, // x int8 EVEN [ng * 8]
    const uint32_t* __restrict__ xt_o, // x int8 ODD [ng * 8]
    const int* __restrict__ x_bias,    // precomputed 4×(sum of 8 x values) per word position [ng * 8]
    const float* __restrict__ xs,      // x scales [ng]
    float* __restrict__ y,
    int cols, int rows, int num_groups
) {
    int wid = threadIdx.x / WARP_SIZE;
    int lane = threadIdx.x % WARP_SIZE;
    int row = blockIdx.x * V29_RPB + wid;
    if (row >= rows) return;

    const uint32_t* row_w = &pt[row * num_groups * V29_WPG];
    const float* row_ws = &ws[row * num_groups];

    float acc = 0.0f;
    int total_words = num_groups * V29_WPG;

    for (int base = 0; base < total_words; base += WARP_SIZE) {
        int w = base + lane;
        if (w < total_words) {
            uint32_t word = __ldg(&row_w[w]);
            int g = w >> 3;
            int wig = w & 7;

            // BIAS DECODE: 3 instructions total (no XOR, no SUB)
            uint32_t evens = word & 0x0F0F0F0F;        // AND
            uint32_t odds = (word >> 4) & 0x0F0F0F0F;  // SHR + AND
            // Values 0-8 are valid positive int8 — no sign extension needed

            int x_idx = g * 8 + wig;
            uint32_t xe = __ldg(&xt_e[x_idx]);
            uint32_t xo = __ldg(&xt_o[x_idx]);

            // dp4a: biased result (includes +4 per weight)
            int dp = __dp4a((int)evens, (int)xe, 0)
                   + __dp4a((int)odds, (int)xo, 0);

            // Subtract precomputed bias: 4 × sum of the 8 x values for this word
            int bias = __ldg(&x_bias[x_idx]);
            dp -= bias;

            float combined_scale = __ldg(&row_ws[g]) * __ldg(&xs[g]);
            acc += (float)dp * combined_scale;
        }
    }

    #pragma unroll
    for (int o = 16; o > 0; o >>= 1)
        acc += __shfl_down_sync(0xFFFFFFFF, acc, o);

    if (lane == 0) y[row] = acc;
}
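The correction term can be sanity-checked on the host with plain integers. This sketch (ours, not part of the library) stores weights as level + 4 so every nibble is unsigned, takes the biased dot product, then recovers the signed result by subtracting 4 × Σx:

```python
# Host-side model of the v29 bias trick on one 8-value word pair.
w = [-4, -1, 0, 2, 4, -3, 1, 0]   # signed d2 levels, range -4..4
x = [7, -2, 5, 0, -9, 3, 1, 4]    # int8 activations

dot_signed = sum(wi * xi for wi, xi in zip(w, x))
dot_biased = sum((wi + 4) * xi for wi, xi in zip(w, x))  # unsigned 0..8 weights
x_bias = 4 * sum(x)               # what the kernel loads as x_bias[x_idx]

assert dot_biased - x_bias == dot_signed   # == -70 for these values
```

Because x_bias depends only on the activations, it is computed once per word position at quantization time and amortized over every row.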

// v29 wrapper moved to extern "C" block below

// Keep v27 as fallback (doesn't need interleaved x)
__device__ __forceinline__ uint32_t extract_int4x4_to_int8x4(uint32_t word, int start) {
    uint32_t result = 0;
    #pragma unroll
    for (int i = 0; i < 4; i++) {
        int shift = (start + i) * 4;
        int nibble = (word >> shift) & 0xF;
        int val = (nibble & 0x8) ? (nibble | 0xFFFFFFF0) : nibble;
        result |= ((uint32_t)(val & 0xFF)) << (i * 8);
    }
    return result;
}

#define V27_RPB 16
#define V27_WPG 8
#define V27_BS (V27_RPB * WARP_SIZE)

__global__ void k_v27(
    const uint32_t* __restrict__ pt,
    const float* __restrict__ ws,
    const uint32_t* __restrict__ xt,
    const float* __restrict__ xs,
    float* __restrict__ y,
    int cols, int rows, int num_groups
) {
    int wid = threadIdx.x / WARP_SIZE;
    int lane = threadIdx.x % WARP_SIZE;
    int row = blockIdx.x * V27_RPB + wid;  // use v27's own macro (was V28_RPB; same value)
    if (row >= rows) return;

    const uint32_t* row_w = &pt[row * num_groups * V27_WPG];
    const float* row_ws = &ws[row * num_groups];

    float acc = 0.0f;
    int total_words = num_groups * V27_WPG;

    for (int base = 0; base < total_words; base += WARP_SIZE) {
        int w = base + lane;
        if (w < total_words) {
            uint32_t word = __ldg(&row_w[w]);
            int g = w >> 3;
            int word_in_group = w & 7;

            uint32_t lo = extract_int4x4_to_int8x4(word, 0);
            uint32_t hi = extract_int4x4_to_int8x4(word, 4);

            int x_base = g * 16 + word_in_group * 2;
            uint32_t x_lo = __ldg(&xt[x_base]);
            uint32_t x_hi = __ldg(&xt[x_base + 1]);

            int dp_lo = __dp4a((int)lo, (int)x_lo, 0);
            int dp_hi = __dp4a((int)hi, (int)x_hi, 0);

            float combined_scale = __ldg(&row_ws[g]) * __ldg(&xs[g]);
            acc += (float)(dp_lo + dp_hi) * combined_scale;
        }
    }

    #pragma unroll
    for (int o = 16; o > 0; o >>= 1)
        acc += __shfl_down_sync(0xFFFFFFFF, acc, o);

    if (lane == 0) y[row] = acc;
}

// ============================================================
// V9-style: trit-packed d3 (for models stored in trit format)
// ============================================================

#define V9R 4
#define V9W 2
#define V9BS (V9R * V9W * WARP_SIZE)

__device__ __forceinline__ float mac_wide_d3(
    const uint32_t* __restrict__ p, const float* __restrict__ x, int tid
) {
    float acc = 0.0f;
    #pragma unroll
    for (int i = 0; i < 4; i++) {
        int idx = tid * 4 + i;
        int w = idx / 5, pos = idx % 5;
        uint32_t bits = (__ldg(&p[w]) >> (pos * 6)) & 0x3F;
        int t0 = bits & 3, t1 = (bits >> 2) & 3, t2 = (bits >> 4) & 3;
        int lv = ((t2==TRIT_POS)-(t2==TRIT_NEG))*9 + ((t1==TRIT_POS)-(t1==TRIT_NEG))*3
               + ((t0==TRIT_POS)-(t0==TRIT_NEG));
        acc += lv * __ldg(&x[idx]);
    }
    return acc;
}

__global__ void k_v9(
    const uint32_t* __restrict__ pt, const float* __restrict__ sc,
    const float* __restrict__ x, float* __restrict__ y,
    int in_f, int out_f, int depth
) {
    __shared__ float parts[V9R * V9W];
    int base = blockIdx.x * V9R;
    int wid = threadIdx.x / WARP_SIZE, lane = threadIdx.x % WARP_SIZE;
    int lr = wid / V9W, rw = wid % V9W, row = base + lr;
    int ng = in_f / GROUP_SIZE;
    const int w = 13;

    int half = lane / 16;
    int tid_in_group = lane % 16;

    float partial = 0.0f;
    if (row < out_f) {
        for (int g_pair = rw; g_pair < (ng + 1) / 2; g_pair += V9W) {
            int g = g_pair * 2 + half;
            if (g < ng) {
                float ga = mac_wide_d3(&pt[(row * ng + g) * w],
                                       &x[g * GROUP_SIZE], tid_in_group);
                unsigned mask = half ? 0xFFFF0000u : 0x0000FFFFu;
                #pragma unroll
                for (int o = 8; o > 0; o >>= 1)
                    ga += __shfl_down_sync(mask, ga, o);
                if (tid_in_group == 0)
                    partial += ga * __ldg(&sc[row * ng + g]);
            }
        }
    }

    float my_partial = (lane == 0 || lane == 16) ? partial : 0.0f;
    my_partial += __shfl_xor_sync(0xFFFFFFFF, my_partial, 16);
    if (lane == 0) parts[wid] = my_partial;
    __syncthreads();
    if (lane == 0 && rw == 0 && row < out_f) {
        float s = 0;
        for (int i = 0; i < V9W; i++) s += parts[lr * V9W + i];
        y[row] = s;
    }
}
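The 6-bit field decode in `mac_wide_d3` is balanced base-3: each 2-bit trit maps to a signed digit through `(t == TRIT_POS) - (t == TRIT_NEG)`, so three trits cover every level in [-13, 13]. A host-side round-trip check, where the `encode_level` inverse is ours and exists only for testing:

```python
TRIT_POS, TRIT_NEG = 1, 2

def decode_level(bits):
    """Mirror mac_wide_d3: three 2-bit trits per 6-bit field, level in [-13, 13]."""
    t0, t1, t2 = bits & 3, (bits >> 2) & 3, (bits >> 4) & 3
    s = lambda t: (t == TRIT_POS) - (t == TRIT_NEG)
    return 9 * s(t2) + 3 * s(t1) + s(t0)

def encode_level(lv):
    """Inverse (ours, test-only): balanced base-3 digits in {-1, 0, +1}."""
    bits = 0
    for d in range(3):
        r = (lv + 1) % 3 - 1          # balanced remainder
        lv = (lv - r) // 3
        bits |= (TRIT_POS if r == 1 else TRIT_NEG if r == -1 else 0) << (2 * d)
    return bits

assert all(decode_level(encode_level(lv)) == lv for lv in range(-13, 14))
```

This is also why each 64-value group needs 13 uint32 words: 5 fields of 6 bits fit per word, and ceil(64 / 5) = 13.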
| 328 |
+
|
| 329 |
+
// ============================================================
|
| 330 |
+
// L2 persistence helpers
|
| 331 |
+
// ============================================================
|
| 332 |
+
|
| 333 |
+
static void set_l2_persist(void* ptr, size_t bytes) {
|
| 334 |
+
cudaStreamAttrValue attr;
|
| 335 |
+
attr.accessPolicyWindow.base_ptr = ptr;
|
| 336 |
+
attr.accessPolicyWindow.num_bytes = bytes;
|
| 337 |
+
attr.accessPolicyWindow.hitRatio = 1.0f;
|
| 338 |
+
attr.accessPolicyWindow.hitProp = cudaAccessPropertyPersisting;
|
| 339 |
+
attr.accessPolicyWindow.missProp = cudaAccessPropertyStreaming;
|
| 340 |
+
cudaStreamSetAttribute(0, cudaStreamAttributeAccessPolicyWindow, &attr);
|
| 341 |
+
}
|
| 342 |
+
|
| 343 |
+
static void clear_l2_persist() {
|
| 344 |
+
cudaStreamAttrValue attr;
|
| 345 |
+
memset(&attr, 0, sizeof(attr));
|
| 346 |
+
cudaStreamSetAttribute(0, cudaStreamAttributeAccessPolicyWindow, &attr);
|
| 347 |
+
}
|
| 348 |
+
|
| 349 |
+
// ============================================================
|
| 350 |
+
// C API — callable from any language via dlopen/ctypes/FFI
|
| 351 |
+
// ============================================================
|
| 352 |
+
|
| 353 |
+
extern "C" {
|
| 354 |
+
|
| 355 |
+
// v27: d2 int4-packed + dp4a (champion for GPU)
|
| 356 |
+
// pt: [rows * ng * 8] int32 (int4 packed weights)
|
| 357 |
+
// ws: [rows * ng] float32 (weight scales)
|
| 358 |
+
// xt: [ng * 16] int32 (int8 packed activations)
|
| 359 |
+
// xs: [ng] float32 (activation scales)
|
| 360 |
+
// y: [rows] float32 (output)
|
| 361 |
+
void trit_gemv_d2_dp4a(
|
| 362 |
+
const int32_t* pt, const float* ws,
|
| 363 |
+
const int32_t* xt, const float* xs,
|
| 364 |
+
float* y, int cols, int rows, int num_groups,
|
| 365 |
+
int use_l2_persist
|
| 366 |
+
) {
|
| 367 |
+
if (use_l2_persist) {
|
| 368 |
+
set_l2_persist((void*)pt, (size_t)rows * num_groups * 8 * sizeof(int32_t));
|
| 369 |
+
}
|
| 370 |
+
k_v27<<<(rows + V27_RPB - 1) / V27_RPB, V27_BS>>>(
|
| 371 |
+
(const uint32_t*)pt, ws, (const uint32_t*)xt, xs, y, cols, rows, num_groups);
|
| 372 |
+
if (use_l2_persist) {
|
| 373 |
+
clear_l2_persist();
|
| 374 |
+
}
|
| 375 |
+
}
|
| 376 |
+
|
| 377 |
+
// v9: trit-packed d3 (for native trit format)
|
| 378 |
+
// pt: [rows * ng * 13] int32 (trit packed weights)
|
| 379 |
+
// sc: [rows * ng] float32 (scales)
|
| 380 |
+
// x: [cols] float32 (activations)
|
| 381 |
+
// y: [rows] float32 (output)
|
| 382 |
+
void trit_gemv_d3_native(
|
| 383 |
+
const int32_t* pt, const float* sc,
|
| 384 |
+
const float* x, float* y,
|
| 385 |
+
int cols, int rows, int depth
|
| 386 |
+
) {
|
| 387 |
+
k_v9<<<(rows + V9R - 1) / V9R, V9BS>>>(
|
| 388 |
+
(const uint32_t*)pt, sc, x, y, cols, rows, depth);
|
| 389 |
+
}
|
| 390 |
+
|
| 391 |
+
// v29: d2 unsigned int4 + bias trick (no sign extension)
|
| 392 |
+
void trit_gemv_d2_bias(
|
| 393 |
+
const int32_t* pt, const float* ws,
|
| 394 |
+
const int32_t* xt_e, const int32_t* xt_o,
|
| 395 |
+
const int32_t* x_bias, const float* xs,
|
| 396 |
+
float* y, int cols, int rows, int num_groups,
|
| 397 |
+
int use_l2_persist
|
| 398 |
+
) {
|
| 399 |
+
if (use_l2_persist) {
|
| 400 |
+
set_l2_persist((void*)pt, (size_t)rows * num_groups * 8 * sizeof(int32_t));
|
| 401 |
+
}
|
| 402 |
+
k_v29<<<(rows + V29_RPB - 1) / V29_RPB, V29_BS>>>(
|
| 403 |
+
(const uint32_t*)pt, ws,
|
| 404 |
+
(const uint32_t*)xt_e, (const uint32_t*)xt_o,
|
| 405 |
+
(const int*)x_bias, xs,
|
| 406 |
+
y, cols, rows, num_groups);
|
| 407 |
+
if (use_l2_persist) {
|
| 408 |
+
clear_l2_persist();
|
| 409 |
+
}
|
| 410 |
+
}
|
| 411 |
+
|
| 412 |
+
// v28: d2 int4 + branchless interleaved decode + dp4a
|
| 413 |
+
// xt_e/xt_o: pre-interleaved x (even/odd nibble positions)
|
| 414 |
+
// Each has ng*8 uint32 words (4 int8 values per word, 8 words per group)
|
| 415 |
+
void trit_gemv_d2_fast(
|
| 416 |
+
const int32_t* pt, const float* ws,
|
| 417 |
+
const int32_t* xt_e, const int32_t* xt_o, const float* xs,
|
| 418 |
+
float* y, int cols, int rows, int num_groups,
|
| 419 |
+
int use_l2_persist
|
| 420 |
+
) {
|
| 421 |
+
if (use_l2_persist) {
|
| 422 |
+
set_l2_persist((void*)pt, (size_t)rows * num_groups * 8 * sizeof(int32_t));
|
| 423 |
+
}
|
| 424 |
+
k_v28<<<(rows + V28_RPB - 1) / V28_RPB, V28_BS>>>(
|
| 425 |
+
(const uint32_t*)pt, ws,
|
| 426 |
+
(const uint32_t*)xt_e, (const uint32_t*)xt_o, xs,
|
| 427 |
+
y, cols, rows, num_groups);
|
| 428 |
+
if (use_l2_persist) {
|
| 429 |
+
clear_l2_persist();
|
| 430 |
+
}
|
| 431 |
+
}
|
| 432 |
+
|
| 433 |
+
// v21f: d3 int8-packed + dp4a (same format as v21f in wrapper)
// wt: [rows * ng * 16] int32 (int8 packed weight levels)
// ws: [rows * ng] float32 (weight scales)
// xt: [ng * 16] int32 (int8 packed activations)
// xs: [ng] float32 (activation scales)
#define V21F_RPB 4
#define V21F_BS (V21F_RPB * WARP_SIZE)

__global__ void k_v21f_standalone(
    const uint32_t* __restrict__ wt,
    const float* __restrict__ ws,
    const uint32_t* __restrict__ xt,
    const float* __restrict__ xs,
    float* __restrict__ y,
    int cols, int rows, int num_groups
) {
    int wid = threadIdx.x / WARP_SIZE;
    int lane = threadIdx.x % WARP_SIZE;
    int row = blockIdx.x * V21F_RPB + wid;
    if (row >= rows) return;

    const uint32_t* row_w = &wt[row * num_groups * 16];
    const float* row_ws = &ws[row * num_groups];
    float acc = 0.0f;
    int total_words = num_groups * 16;

    for (int base = 0; base < total_words; base += WARP_SIZE) {
        int w = base + lane;
        if (w < total_words) {
            uint32_t w_word = __ldg(&row_w[w]);
            uint32_t x_word = __ldg(&xt[w]);
            int dp = __dp4a((int)w_word, (int)x_word, 0);
            int g = w >> 4; // 16 words per group
            acc += (float)dp * __ldg(&row_ws[g]) * __ldg(&xs[g]);
        }
    }

    #pragma unroll
    for (int o = 16; o > 0; o >>= 1)
        acc += __shfl_down_sync(0xFFFFFFFF, acc, o);

    if (lane == 0) y[row] = acc;
}

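// NOTE (added sketch): a scalar CPU reference for the grouped int8 GEMV that
// k_v21f_standalone computes is useful for validating kernel output. This is
// an assumed reference (name ref_gemv_i8 is hypothetical), built only from
// the layout comments above: 16 words = 64 int8 values per group, and each
// group's integer dot product is scaled by ws[row*ng + g] * xs[g]:

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>

// Hypothetical CPU reference: per row, sum per-group int8 dot products,
// each scaled by the matching weight and activation group scales.
void ref_gemv_i8(const int8_t* wt, const float* ws,
                 const int8_t* xt, const float* xs,
                 float* y, int rows, int num_groups) {
    const int GROUP = 64; // 16 uint32 words * 4 int8 values
    for (int r = 0; r < rows; r++) {
        float acc = 0.0f;
        for (int g = 0; g < num_groups; g++) {
            int dp = 0;
            for (int i = 0; i < GROUP; i++)
                dp += (int)wt[(size_t)(r * num_groups + g) * GROUP + i]
                    * (int)xt[(size_t)g * GROUP + i];
            acc += (float)dp * ws[r * num_groups + g] * xs[g];
        }
        y[r] = acc;
    }
}
```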
// ============================================================
// D3 HARDENED: int8 dp4a, 16 RPB, L2 persist, deferred reduction
//
// d3 levels: -13 to +13 (27 values), stored as int8 (1 byte each)
// 16 words per group, 4 int8 values per word = 64 values/group
// Division by 16 = shift (no div-by-13 problem!)
//
// This is the SIMPLEST kernel — no decode at all.
// The int8 values go DIRECTLY into dp4a.
// Pure memory-bound on every GPU.
// ============================================================

#define D3H_RPB 16
#define D3H_WPG 16 // 16 uint32 words per group (4 int8 each = 64 values)
#define D3H_BS (D3H_RPB * WARP_SIZE)

__global__ void k_d3_hardened(
    const uint32_t* __restrict__ wt, // int8 packed: 16 words per group
    const float* __restrict__ ws,    // weight scales [rows * ng]
    const uint32_t* __restrict__ xt, // x int8 packed: 16 words per group
    const float* __restrict__ xs,    // x scales [ng]
    float* __restrict__ y,
    int cols, int rows, int num_groups
) {
    int wid = threadIdx.x / WARP_SIZE;
    int lane = threadIdx.x % WARP_SIZE;
    int row = blockIdx.x * D3H_RPB + wid;
    if (row >= rows) return;

    const uint32_t* row_w = &wt[row * num_groups * D3H_WPG];
    const float* row_ws = &ws[row * num_groups];

    float acc = 0.0f;
    int total_words = num_groups * D3H_WPG;

    for (int base = 0; base < total_words; base += WARP_SIZE) {
        int w = base + lane;
        if (w < total_words) {
            // COALESCED load — 32 threads × 4 bytes = 128 bytes
            uint32_t w_word = __ldg(&row_w[w]);
            uint32_t x_word = __ldg(&xt[w]);

            // dp4a: 4× int8 multiply-accumulate — ZERO decode
            int dp = __dp4a((int)w_word, (int)x_word, 0);

            // Group index: SHIFT (16 is a power of 2!)
            int g = w >> 4;

            // Deferred: accumulate per-thread, ONE reduction at end
            acc += (float)dp * __ldg(&row_ws[g]) * __ldg(&xs[g]);
        }
    }

    // ONE warp reduction
    #pragma unroll
    for (int o = 16; o > 0; o >>= 1)
        acc += __shfl_down_sync(0xFFFFFFFF, acc, o);

    if (lane == 0) y[row] = acc;
}

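// NOTE (added sketch): for readers unfamiliar with __dp4a — it multiplies the
// four signed int8 lanes of its two operands pairwise and adds the four
// products to the int32 accumulator. A host-side emulation (dp4a_ref is a
// hypothetical name, not part of this file) makes the kernel's inner line
// checkable off-GPU:

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical host emulation of __dp4a(a, b, c): treat a and b as four
// signed int8 lanes each, multiply lane-wise, and sum into c.
static inline int dp4a_ref(uint32_t a, uint32_t b, int c) {
    for (int i = 0; i < 4; i++) {
        int8_t av = (int8_t)((a >> (8 * i)) & 0xFF);
        int8_t bv = (int8_t)((b >> (8 * i)) & 0xFF);
        c += (int)av * (int)bv;
    }
    return c;
}
```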
// d3 hardened: uses k_d3_hardened (16 RPB, deferred reduction)
void trit_gemv_d3_int8_dp4a(
    const int32_t* wt, const float* ws,
    const int32_t* xt, const float* xs,
    float* y, int cols, int rows, int num_groups,
    int use_l2_persist
) {
    if (use_l2_persist) {
        set_l2_persist((void*)wt, (size_t)rows * num_groups * 16 * sizeof(int32_t));
    }
    k_d3_hardened<<<(rows + D3H_RPB - 1) / D3H_RPB, D3H_BS>>>(
        (const uint32_t*)wt, ws, (const uint32_t*)xt, xs, y, cols, rows, num_groups);
    if (use_l2_persist) {
        clear_l2_persist();
    }
}

// Run the same layer N times back-to-back to measure pipeline / L2 reuse benefit
void trit_gemv_pipeline_bench(
    const int32_t* pt, const float* ws,
    const int32_t* xt_e, const int32_t* xt_o, const float* xs,
    float* y, int cols, int rows, int num_groups,
    int n_repeats, int use_l2_persist
) {
    if (use_l2_persist) {
        set_l2_persist((void*)pt, (size_t)rows * num_groups * 8 * sizeof(int32_t));
    }
    // Launch n_repeats sequential v28 kernels in the SAME stream
    // This measures the pipeline benefit: back-to-back launches share L2
    for (int i = 0; i < n_repeats; i++) {
        k_v28<<<(rows + V28_RPB - 1) / V28_RPB, V28_BS>>>(
            (const uint32_t*)pt, ws,
            (const uint32_t*)xt_e, (const uint32_t*)xt_o, xs,
            y, cols, rows, num_groups);
    }
    if (use_l2_persist) {
        clear_l2_persist();
    }
}

// Query L2 cache size (for deciding whether to use L2 persist)
int get_l2_cache_bytes() {
    cudaDeviceProp prop;
    cudaGetDeviceProperties(&prop, 0);
    return prop.l2CacheSize;
}

// Query GPU name
void get_gpu_name(char* buf, int buflen) {
    cudaDeviceProp prop;
    cudaGetDeviceProperties(&prop, 0);
    strncpy(buf, prop.name, buflen - 1);
    buf[buflen - 1] = '\0';
}

// Synchronize (for timing from Python)
void cuda_sync() {
    cudaDeviceSynchronize();
}

} // extern "C"
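
// NOTE (added sketch): get_l2_cache_bytes exists so a caller can decide
// whether the weight table is small enough for L2 persistence to pay off.
// One plausible host-side policy — the function name and the 0.75 fraction
// are assumptions for illustration, not values taken from this repo:

```cpp
#include <cassert>
#include <cstddef>

// Assumed policy sketch: enable L2 persistence only when the packed weight
// table fits in a fraction of L2, leaving headroom for activations and y.
// The 0.75 fraction is a guess, not a tuned value from this repo.
static inline int should_use_l2_persist(size_t weight_bytes, size_t l2_bytes) {
    return weight_bytes > 0 && (double)weight_bytes <= 0.75 * (double)l2_bytes;
}
```

// For example, a 4096x4096 d3 layer at 1 byte per packed weight is 16 MiB of
// words, which would not fit a 6 MiB L2 — so a caller following this policy
// would pass use_l2_persist = 0 there.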