tritllm-kernel / trit_gemv.cu
fix: address codex review BLOCKERs and SHOULD-FIXes; update KNOWN_ISSUES
/*
* TritLLM CUDA Kernel — Ternary GEMV (Matrix-Vector Multiply)
*
* Core operation: y = W_ternary @ x
* Where W_ternary is packed ternary weights with per-group scales.
*
* Each group of 64 weights has:
* - A depth (1-4 trits per weight)
* - A FP16 scale factor
* - Packed trit values (2 bits per trit: 00=0, 01=+1, 10=-1, 11=unused)
*
 * The key: no per-weight dequantize in the inner loop. A depth-1 ternary MAC
 * is just a conditional add/subtract; higher depths need only one
 * small-integer multiply per weight, with the FP16 scale applied per group.
*/
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <stdint.h>
#define GROUP_SIZE 64
#define WARP_SIZE 32
// Trit encoding: 2 bits per trit
// 00 = 0, 01 = +1, 10 = -1
#define TRIT_ZERO 0
#define TRIT_POS 1
#define TRIT_NEG 2
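/*
 * Host-side reference for the depth-1 layout consumed by trit_mac_d1 below.
 * Illustrative sketch only, not part of the kernel; the helper names
 * trit_encode and pack_group_d1 are hypothetical.
 */
static inline uint32_t trit_encode(int w) {
    // Signed weight to 2-bit code: 0 -> 00, +1 -> 01, -1 -> 10.
    return (w > 0) ? TRIT_POS : (w < 0) ? TRIT_NEG : TRIT_ZERO;
}
static inline void pack_group_d1(const int* w,    // 64 signed weights in {-1, 0, +1}
                                 uint32_t* out) { // 4 packed uint32s
    for (int word = 0; word < 4; word++) {
        uint32_t bits = 0;
        for (int j = 0; j < 16; j++) {
            bits |= trit_encode(w[word * 16 + j]) << (j * 2);
        }
        out[word] = bits;
    }
}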
/*
* Depth 1 (3 levels: {-1, 0, +1}): 1 trit per weight, 2 bits per weight
* Pack 16 trits per uint32 (16 * 2 = 32 bits)
* Group of 64 = 4 uint32s
*
* Inner loop: read trit, branch-free conditional accumulate
*/
__device__ __forceinline__ float trit_mac_d1(
const uint32_t* __restrict__ packed, // 4 uint32s = 64 trits
const float* __restrict__ x, // 64 activations
int lane // warp lane (0-31)
) {
float acc = 0.0f;
// Each thread in warp handles 2 elements (64 / 32 = 2)
#pragma unroll
for (int i = 0; i < 2; i++) {
int idx = lane * 2 + i;
int word = idx / 16; // which uint32 (0-3)
int bit_offset = (idx % 16) * 2; // bit position within word
uint32_t trit = (packed[word] >> bit_offset) & 0x3;
float val = x[idx];
// Branch-free: acc += (trit == 1) * val - (trit == 2) * val
acc += ((trit == TRIT_POS) - (trit == TRIT_NEG)) * val;
}
return acc;
}
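/*
 * Worked example (illustrative): packed[0] = 0x00000009 encodes w[0] = +1
 * (bits 01) and w[1] = -1 (bits 10), with w[2..15] = 0, so lane 0 contributes
 * x[0] - x[1] to the group accumulator.
 */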
/*
* Depth 2 (9 levels: {-4..+4}): 2 trits per weight, 4 bits per weight
 * Value = sign(trit1) * 3 + sign(trit0), with each sign in {-1, 0, +1} (maps to -4..+4)
* Pack 8 values per uint32 (8 * 4 = 32 bits)
* Group of 64 = 8 uint32s
*/
__device__ __forceinline__ float trit_mac_d2(
const uint32_t* __restrict__ packed, // 8 uint32s = 64 values
const float* __restrict__ x,
int lane
) {
float acc = 0.0f;
#pragma unroll
for (int i = 0; i < 2; i++) {
int idx = lane * 2 + i;
int word = idx / 8;
int bit_offset = (idx % 8) * 4;
uint32_t bits = (packed[word] >> bit_offset) & 0xF;
// Decode: trit1 = bits >> 2, trit0 = bits & 0x3
// value = (trit1_sign * 3 + trit0_sign)
// where trit_sign: 00->0, 01->+1, 10->-1
int t0 = (int)(bits & 0x3);
int t1 = (int)((bits >> 2) & 0x3);
int sign0 = (t0 == TRIT_POS) - (t0 == TRIT_NEG);
int sign1 = (t1 == TRIT_POS) - (t1 == TRIT_NEG);
int level = sign1 * 3 + sign0; // -4 to +4
        // level is a small integer in [-4, +4]; level * x[idx] lowers to an
        // int-to-float convert plus a single FMA (no per-weight dequantize).
acc += level * x[idx];
}
return acc;
}
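/*
 * Worked example (illustrative): bits = 0b0110 decodes to t0 = 10 (-1) and
 * t1 = 01 (+1), so level = (+1) * 3 + (-1) = +2.
 */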
/*
* Depth 3 (27 levels: {-13..+13}): 3 trits per weight, 6 bits per weight
* Pack 5 values per uint32 (5 * 6 = 30 bits, 2 wasted)
* Group of 64 = 13 uint32s (64 values, last uint32 has 4 values)
*/
__device__ __forceinline__ float trit_mac_d3(
const uint32_t* __restrict__ packed, // 13 uint32s
const float* __restrict__ x,
int lane
) {
float acc = 0.0f;
#pragma unroll
for (int i = 0; i < 2; i++) {
int idx = lane * 2 + i;
int word = idx / 5;
int pos = idx % 5;
int bit_offset = pos * 6;
uint32_t bits = (packed[word] >> bit_offset) & 0x3F;
int t0 = (int)(bits & 0x3);
int t1 = (int)((bits >> 2) & 0x3);
int t2 = (int)((bits >> 4) & 0x3);
int s0 = (t0 == TRIT_POS) - (t0 == TRIT_NEG);
int s1 = (t1 == TRIT_POS) - (t1 == TRIT_NEG);
int s2 = (t2 == TRIT_POS) - (t2 == TRIT_NEG);
int level = s2 * 9 + s1 * 3 + s0; // -13 to +13
acc += level * x[idx];
}
return acc;
}
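/*
 * Worked example (illustrative): idx = 63 maps to word 12, pos 3, bit_offset 18,
 * so the 13th uint32 holds only the last four values (indices 60..63), matching
 * the "last uint32 has 4 values" note above.
 */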
/*
* Depth 4 (81 levels: {-40..+40}): 4 trits per weight, 8 bits per weight
* Pack 4 values per uint32 (4 * 8 = 32 bits, perfect)
* Group of 64 = 16 uint32s
*/
__device__ __forceinline__ float trit_mac_d4(
const uint32_t* __restrict__ packed, // 16 uint32s
const float* __restrict__ x,
int lane
) {
float acc = 0.0f;
#pragma unroll
for (int i = 0; i < 2; i++) {
int idx = lane * 2 + i;
int word = idx / 4;
int bit_offset = (idx % 4) * 8;
uint32_t bits = (packed[word] >> bit_offset) & 0xFF;
int t0 = (int)(bits & 0x3);
int t1 = (int)((bits >> 2) & 0x3);
int t2 = (int)((bits >> 4) & 0x3);
int t3 = (int)((bits >> 6) & 0x3);
int s0 = (t0 == TRIT_POS) - (t0 == TRIT_NEG);
int s1 = (t1 == TRIT_POS) - (t1 == TRIT_NEG);
int s2 = (t2 == TRIT_POS) - (t2 == TRIT_NEG);
int s3 = (t3 == TRIT_POS) - (t3 == TRIT_NEG);
int level = s3 * 27 + s2 * 9 + s1 * 3 + s0;
acc += level * x[idx];
}
return acc;
}
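/*
 * Worked example (illustrative): the all-positive code 0x55 (t3..t0 = 01 01 01 01)
 * decodes to level = 27 + 9 + 3 + 1 = +40, the top of the 81-level range.
 */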
/*
* Main GEMV kernel: y[out_features] = W[out_features, in_features] @ x[in_features]
*
* W is stored as packed ternary groups:
* - packed_trits: variable-length packed trit data per group
* - scales: FP16 scale per group
* - depths: uint8 depth per group (1-4)
 * - group_offsets: uint32 word offset into packed_trits for each group
*
* One warp per output row, iterating over groups along the input dimension.
* Warp reduction gives the final dot product.
*/
// Simplified version: uniform depth across all groups in a tensor
// (variable-depth version below)
// Launch contract: blockDim.x == 32 (one warp per block), in_features % 64 == 0.
// The kernel uses lane = threadIdx.x and a full-warp shuffle mask, so larger
// blocks would alias the lane index and race on y[row]. If either contract is
// violated the kernel returns early without writing y; trailing partial groups
// are an unsupported shape rather than being silently dropped.
__global__ void trit_gemv_uniform(
const uint32_t* __restrict__ packed_trits, // packed trit data
    const float* __restrict__ scales,       // [out_features * num_groups] scales (FP16 upcast to float)
const float* __restrict__ x, // [in_features]
float* __restrict__ y, // [out_features]
int in_features,
int out_features,
int depth // uniform depth 1-4
) {
if (blockDim.x != WARP_SIZE) return; // launch contract: 1 warp/block
if (in_features % GROUP_SIZE) return; // launch contract: K mod 64 == 0
int row = blockIdx.x; // one block per output row
if (row >= out_features) return;
int lane = threadIdx.x; // lane within warp (0-31)
int num_groups = in_features / GROUP_SIZE;
// Words per group depends on depth
int words_per_group;
switch (depth) {
case 1: words_per_group = 4; break; // 64 * 2 / 32
case 2: words_per_group = 8; break; // 64 * 4 / 32
        case 3: words_per_group = 13; break; // ceil(64 / 5): 5 values per uint32
case 4: words_per_group = 16; break; // 64 * 8 / 32
default: words_per_group = 4; break;
}
float row_acc = 0.0f;
for (int g = 0; g < num_groups; g++) {
int group_offset = (row * num_groups + g) * words_per_group;
const uint32_t* group_data = &packed_trits[group_offset];
const float* group_x = &x[g * GROUP_SIZE];
float scale = scales[row * num_groups + g];
float group_acc;
switch (depth) {
case 1: group_acc = trit_mac_d1(group_data, group_x, lane); break;
case 2: group_acc = trit_mac_d2(group_data, group_x, lane); break;
case 3: group_acc = trit_mac_d3(group_data, group_x, lane); break;
case 4: group_acc = trit_mac_d4(group_data, group_x, lane); break;
default: group_acc = 0.0f; break;
}
// Warp reduction
#pragma unroll
for (int offset = 16; offset > 0; offset >>= 1) {
group_acc += __shfl_down_sync(0xFFFFFFFF, group_acc, offset);
}
// Lane 0 accumulates the scaled result
if (lane == 0) {
row_acc += group_acc * scale;
}
}
// Write output
if (lane == 0) {
y[row] = row_acc;
}
}
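/*
 * Host-side launch sketch (illustrative, not part of this kernel). It simply
 * applies the launch contract above: one warp-sized block per output row.
 * The wrapper name and device-pointer arguments are hypothetical.
 */
static void launch_trit_gemv_uniform(const uint32_t* d_packed, const float* d_scales,
                                     const float* d_x, float* d_y,
                                     int in_features, int out_features, int depth,
                                     cudaStream_t stream) {
    dim3 grid(out_features); // one block (= one warp) per output row
    dim3 block(WARP_SIZE);   // launch contract: blockDim.x == 32
    trit_gemv_uniform<<<grid, block, 0, stream>>>(
        d_packed, d_scales, d_x, d_y, in_features, out_features, depth);
}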
/*
* Variable-depth version: each group can have a different depth.
* Uses a depth map and offset table to handle mixed-depth tensors.
*/
// Launch contract: blockDim.x == 32 (one warp per block), in_features % 64 == 0.
__global__ void trit_gemv_variable(
const uint32_t* __restrict__ packed_trits,
const float* __restrict__ scales,
const uint8_t* __restrict__ depth_map, // [num_groups_per_row] depth per group
    const int* __restrict__ group_offsets,    // [num_groups_per_row + 1] word offsets (prefix sum; last entry = words per row)
const float* __restrict__ x,
float* __restrict__ y,
int in_features,
int out_features
) {
if (blockDim.x != WARP_SIZE) return;
if (in_features % GROUP_SIZE) return;
int row = blockIdx.x;
if (row >= out_features) return;
int lane = threadIdx.x;
int num_groups = in_features / GROUP_SIZE;
float row_acc = 0.0f;
for (int g = 0; g < num_groups; g++) {
int depth = depth_map[g];
        int word_offset = group_offsets[g] + row * group_offsets[num_groups]; // group_offsets[num_groups] = words per row
const uint32_t* group_data = &packed_trits[word_offset];
const float* group_x = &x[g * GROUP_SIZE];
float scale = scales[row * num_groups + g];
float group_acc;
switch (depth) {
case 1: group_acc = trit_mac_d1(group_data, group_x, lane); break;
case 2: group_acc = trit_mac_d2(group_data, group_x, lane); break;
case 3: group_acc = trit_mac_d3(group_data, group_x, lane); break;
case 4: group_acc = trit_mac_d4(group_data, group_x, lane); break;
default: group_acc = 0.0f; break;
}
#pragma unroll
for (int offset = 16; offset > 0; offset >>= 1) {
group_acc += __shfl_down_sync(0xFFFFFFFF, group_acc, offset);
}
if (lane == 0) {
row_acc += group_acc * scale;
}
}
if (lane == 0) {
y[row] = row_acc;
}
}
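/*
 * Host-side sketch (illustrative, not part of this kernel) of how the
 * group_offsets table for trit_gemv_variable can be built from depth_map.
 * group_offsets[g] is the word offset of group g within a row; the final
 * entry group_offsets[num_groups] is the per-row word stride used above.
 * The helper names words_for_depth and build_group_offsets are hypothetical.
 */
static inline int words_for_depth(int depth) {
    switch (depth) {
        case 1: return 4;  // 16 trits  (2 bits each) per uint32
        case 2: return 8;  //  8 values (4 bits each) per uint32
        case 3: return 13; //  5 values (6 bits each) per uint32, ceil(64 / 5)
        case 4: return 16; //  4 values (8 bits each) per uint32
        default: return 4;
    }
}
static inline void build_group_offsets(const uint8_t* depth_map, int num_groups,
                                       int* group_offsets) { // [num_groups + 1]
    group_offsets[0] = 0;
    for (int g = 0; g < num_groups; g++) {
        group_offsets[g + 1] = group_offsets[g] + words_for_depth(depth_map[g]);
    }
}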