| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| #include <cuda_fp16.h> |
| #include <cuda_runtime.h> |
| #include <stdint.h> |
|
|
// Number of consecutive input elements that share one scale factor. The
// kernels in this file require in_features to be a multiple of this value.
#define GROUP_SIZE 64
// Block size (in threads) the kernels in this file insist on: one warp.
#define WARP_SIZE 32

// 2-bit trit codes as stored in the packed weight words. The decoders map
// TRIT_POS to +1 and TRIT_NEG to -1; every other code contributes 0.
#define TRIT_ZERO 0
#define TRIT_POS 1
#define TRIT_NEG 2
|
|
| |
| |
| |
| |
| |
| |
| |
// Per-lane partial dot product for depth-1 weights (one trit per element).
//
// Each element is a single 2-bit code, sixteen codes per 32-bit word. The
// caller's lane processes elements 2*lane and 2*lane + 1 and gets back the
// un-reduced, un-scaled sum of sign * x over those two elements.
//
//   packed  first packed word of the group being processed
//   x       first activation of the group being processed
//   lane    lane index; selects which two elements this call covers
__device__ __forceinline__ float trit_mac_d1(
    const uint32_t* __restrict__ packed,
    const float* __restrict__ x,
    int lane
) {
    const int first = lane * 2;
    float partial = 0.0f;

    #pragma unroll
    for (int k = 0; k < 2; ++k) {
        const int elem = first + k;

        // Locate the element's 2-bit code: 16 codes per word, 2 bits each.
        const uint32_t word_bits = packed[elem / 16];
        const uint32_t code = (word_bits >> ((elem % 16) * 2)) & 0x3u;
        const float activation = x[elem];

        // +1 for TRIT_POS, -1 for TRIT_NEG, 0 for any other code.
        const int sign = (code == TRIT_POS) ? 1 : ((code == TRIT_NEG) ? -1 : 0);
        partial += sign * activation;
    }

    return partial;
}
|
|
| |
| |
| |
| |
| |
| |
// Per-lane partial dot product for depth-2 weights (two trits per element).
//
// Each element occupies a 4-bit field, eight fields per 32-bit word. Inside a
// field the low 2-bit code has weight 1 and the high 2-bit code has weight 3,
// so the decoded integer level is 3*s1 + s0 with each s in {-1, 0, +1}.
// The caller's lane processes elements 2*lane and 2*lane + 1 and gets back
// the un-reduced, un-scaled sum of level * x over those two elements.
//
//   packed  first packed word of the group being processed
//   x       first activation of the group being processed
//   lane    lane index; selects which two elements this call covers
__device__ __forceinline__ float trit_mac_d2(
    const uint32_t* __restrict__ packed,
    const float* __restrict__ x,
    int lane
) {
    const int first = lane * 2;
    float partial = 0.0f;

    #pragma unroll
    for (int k = 0; k < 2; ++k) {
        const int elem = first + k;

        // 8 fields per word, 4 bits each.
        const uint32_t field = (packed[elem / 8] >> ((elem % 8) * 4)) & 0xFu;

        // Accumulate the balanced-ternary value, least significant trit first.
        int level = 0;
        int weight = 1;
        #pragma unroll
        for (int t = 0; t < 2; ++t) {
            const uint32_t code = (field >> (2 * t)) & 0x3u;
            const int sign = (code == TRIT_POS) ? 1 : ((code == TRIT_NEG) ? -1 : 0);
            level += weight * sign;
            weight *= 3;
        }

        partial += level * x[elem];
    }

    return partial;
}
|
|
| |
| |
| |
| |
| |
// Per-lane partial dot product for depth-3 weights (three trits per element).
//
// Each element occupies a 6-bit field and five fields are stored per 32-bit
// word (bits 0..29; bits 30 and 31 are never read here). The three 2-bit codes
// in a field carry weights 1, 3 and 9 from low bits to high bits, giving an
// integer level of 9*s2 + 3*s1 + s0 with each s in {-1, 0, +1}.
// The caller's lane processes elements 2*lane and 2*lane + 1 and gets back
// the un-reduced, un-scaled sum of level * x over those two elements.
//
//   packed  first packed word of the group being processed
//   x       first activation of the group being processed
//   lane    lane index; selects which two elements this call covers
__device__ __forceinline__ float trit_mac_d3(
    const uint32_t* __restrict__ packed,
    const float* __restrict__ x,
    int lane
) {
    const int first = lane * 2;
    float partial = 0.0f;

    #pragma unroll
    for (int k = 0; k < 2; ++k) {
        const int elem = first + k;

        // 5 fields per word, 6 bits each.
        const int slot = elem % 5;
        const uint32_t field = (packed[elem / 5] >> (slot * 6)) & 0x3Fu;

        // Accumulate the balanced-ternary value, least significant trit first.
        int level = 0;
        int weight = 1;
        #pragma unroll
        for (int t = 0; t < 3; ++t) {
            const uint32_t code = (field >> (2 * t)) & 0x3u;
            const int sign = (code == TRIT_POS) ? 1 : ((code == TRIT_NEG) ? -1 : 0);
            level += weight * sign;
            weight *= 3;
        }

        partial += level * x[elem];
    }

    return partial;
}
|
|
| |
| |
| |
| |
| |
// Per-lane partial dot product for depth-4 weights (four trits per element).
//
// Each element occupies one byte, four bytes per 32-bit word. The four 2-bit
// codes in a byte carry weights 1, 3, 9 and 27 from low bits to high bits,
// giving an integer level of 27*s3 + 9*s2 + 3*s1 + s0 with each s in
// {-1, 0, +1}.
// The caller's lane processes elements 2*lane and 2*lane + 1 and gets back
// the un-reduced, un-scaled sum of level * x over those two elements.
//
//   packed  first packed word of the group being processed
//   x       first activation of the group being processed
//   lane    lane index; selects which two elements this call covers
__device__ __forceinline__ float trit_mac_d4(
    const uint32_t* __restrict__ packed,
    const float* __restrict__ x,
    int lane
) {
    const int first = lane * 2;
    float partial = 0.0f;

    #pragma unroll
    for (int k = 0; k < 2; ++k) {
        const int elem = first + k;

        // 4 fields per word, 8 bits each.
        const uint32_t field = (packed[elem / 4] >> ((elem % 4) * 8)) & 0xFFu;

        // Accumulate the balanced-ternary value, least significant trit first.
        int level = 0;
        int weight = 1;
        #pragma unroll
        for (int t = 0; t < 4; ++t) {
            const uint32_t code = (field >> (2 * t)) & 0x3u;
            const int sign = (code == TRIT_POS) ? 1 : ((code == TRIT_NEG) ? -1 : 0);
            level += weight * sign;
            weight *= 3;
        }

        partial += level * x[elem];
    }

    return partial;
}
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| |
| |
| |
| |
| |
| |
// Matrix-vector product y = W * x for trit-packed weights where every group
// uses the same depth.
//
// Launch layout: one block per output row (row = blockIdx.x), exactly
// WARP_SIZE threads per block. The kernel returns without writing y when the
// block size differs or when in_features is not a multiple of GROUP_SIZE.
//
//   packed_trits  packed weight words, indexed as
//                 (row * num_groups + group) * words_per_group + word
//   scales        one float per (row, group), indexed row * num_groups + group
//   x             input vector; group g reads x[g * GROUP_SIZE ...]
//   y             output vector; y[row] is written by lane 0
//   in_features   input length; must be a multiple of GROUP_SIZE
//   out_features  number of rows; blocks with row >= out_features exit early
//   depth         1..4 selects trit_mac_d1..d4 (4, 8, 13 or 16 words per
//                 group); for any other value every group contributes 0
__global__ void trit_gemv_uniform(
    const uint32_t* __restrict__ packed_trits,
    const float* __restrict__ scales,
    const float* __restrict__ x,
    float* __restrict__ y,
    int in_features,
    int out_features,
    int depth
) {
    // Precondition checks are uniform across the block, so either every
    // thread returns here or none does.
    if (blockDim.x != WARP_SIZE) return;
    if (in_features % GROUP_SIZE) return;

    int row = blockIdx.x;
    if (row >= out_features) return;

    int lane = threadIdx.x;
    int num_groups = in_features / GROUP_SIZE;

    // Packed words occupied by one group at each depth.
    int words_per_group;
    switch (depth) {
        case 1:  words_per_group = 4;  break;  // 16 elements per word
        case 2:  words_per_group = 8;  break;  // 8 elements per word
        case 3:  words_per_group = 13; break;  // 5 elements per word, rounded up
        case 4:  words_per_group = 16; break;  // 4 elements per word
        default: words_per_group = 4;  break;
    }

    // Row bases are computed in 64-bit: row * num_groups * words_per_group
    // can exceed the range of a 32-bit int for large matrices.
    const int64_t row_first_group = (int64_t)row * num_groups;
    const uint32_t* row_words = packed_trits + row_first_group * words_per_group;
    const float* row_scales = scales + row_first_group;

    float row_acc = 0.0f;

    for (int g = 0; g < num_groups; g++) {
        const uint32_t* group_data = row_words + (int64_t)g * words_per_group;
        const float* group_x = x + g * GROUP_SIZE;
        float scale = row_scales[g];

        // Each lane handles two of the group's elements.
        float group_acc;
        switch (depth) {
            case 1:  group_acc = trit_mac_d1(group_data, group_x, lane); break;
            case 2:  group_acc = trit_mac_d2(group_data, group_x, lane); break;
            case 3:  group_acc = trit_mac_d3(group_data, group_x, lane); break;
            case 4:  group_acc = trit_mac_d4(group_data, group_x, lane); break;
            default: group_acc = 0.0f; break;
        }

        // Shuffle-down tree: afterwards lane 0 holds the sum over all lanes.
        #pragma unroll
        for (int offset = 16; offset > 0; offset >>= 1) {
            group_acc += __shfl_down_sync(0xFFFFFFFF, group_acc, offset);
        }

        // Only lane 0 has the complete group sum, so only it accumulates.
        if (lane == 0) {
            row_acc += group_acc * scale;
        }
    }

    if (lane == 0) {
        y[row] = row_acc;
    }
}
|
|
| |
| |
| |
| |
| |
// Matrix-vector product y = W * x for trit-packed weights where the depth is
// chosen per group index (the same depth for a given group in every row).
//
// Launch layout: one block per output row (row = blockIdx.x), exactly
// WARP_SIZE threads per block. The kernel returns without writing y when the
// block size differs or when in_features is not a multiple of GROUP_SIZE.
//
//   packed_trits   packed weight words; group g of a row starts at
//                  row * group_offsets[num_groups] + group_offsets[g]
//   scales         one float per (row, group), indexed row * num_groups + group
//   depth_map      depth_map[g] is the depth (1..4) of group g; for any other
//                  value the group contributes 0
//   group_offsets  num_groups + 1 entries: word offset of each group within a
//                  row, followed by the number of words per row
//   x              input vector; group g reads x[g * GROUP_SIZE ...]
//   y              output vector; y[row] is written by lane 0
//   in_features    input length; must be a multiple of GROUP_SIZE
//   out_features   number of rows; blocks with row >= out_features exit early
__global__ void trit_gemv_variable(
    const uint32_t* __restrict__ packed_trits,
    const float* __restrict__ scales,
    const uint8_t* __restrict__ depth_map,
    const int* __restrict__ group_offsets,
    const float* __restrict__ x,
    float* __restrict__ y,
    int in_features,
    int out_features
) {
    // Precondition checks are uniform across the block, so either every
    // thread returns here or none does.
    if (blockDim.x != WARP_SIZE) return;
    if (in_features % GROUP_SIZE) return;

    int row = blockIdx.x;
    if (row >= out_features) return;

    int lane = threadIdx.x;
    int num_groups = in_features / GROUP_SIZE;

    // Words per row is loop-invariant, so it is read once. The guard keeps
    // group_offsets untouched when there are no groups at all.
    const int64_t words_per_row = (num_groups > 0) ? group_offsets[num_groups] : 0;

    // Row bases are computed in 64-bit: row * words_per_row and
    // row * num_groups can exceed the range of a 32-bit int for large matrices.
    const uint32_t* row_words = packed_trits + (int64_t)row * words_per_row;
    const float* row_scales = scales + (int64_t)row * num_groups;

    float row_acc = 0.0f;

    for (int g = 0; g < num_groups; g++) {
        int depth = depth_map[g];
        const uint32_t* group_data = row_words + group_offsets[g];
        const float* group_x = x + g * GROUP_SIZE;
        float scale = row_scales[g];

        // Each lane handles two of the group's elements.
        float group_acc;
        switch (depth) {
            case 1:  group_acc = trit_mac_d1(group_data, group_x, lane); break;
            case 2:  group_acc = trit_mac_d2(group_data, group_x, lane); break;
            case 3:  group_acc = trit_mac_d3(group_data, group_x, lane); break;
            case 4:  group_acc = trit_mac_d4(group_data, group_x, lane); break;
            default: group_acc = 0.0f; break;
        }

        // Shuffle-down tree: afterwards lane 0 holds the sum over all lanes.
        #pragma unroll
        for (int offset = 16; offset > 0; offset >>= 1) {
            group_acc += __shfl_down_sync(0xFFFFFFFF, group_acc, offset);
        }

        // Only lane 0 has the complete group sum, so only it accumulates.
        if (lane == 0) {
            row_acc += group_acc * scale;
        }
    }

    if (lane == 0) {
        y[row] = row_acc;
    }
}
|
|