perf: optimized AVX2 kernel + COM6-inspired matmul dispatch (0.2 -> 3.43 t/s)
- ggml/src/ggml-cpu/arch/x86/quants.c  +47 -43
- ggml/src/ggml-cpu/ggml-cpu.c         +12 -6
ggml/src/ggml-cpu/arch/x86/quants.c  CHANGED

@@ -65,52 +65,56 @@ static inline int hsum_i32_4(const __m128i a) {
     return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32));
 }
 
-#if defined(__AVX2__)
-            0x0101010101010101, 0x0000000000000000);
-    __m256i bytes = _mm256_shuffle_epi8(_mm256_set1_epi32(x32), shuf_mask);
-    const __m256i bit_mask = _mm256_set1_epi64x(0x7fbfdfeff7fbfdfe);
-    bytes = _mm256_or_si256(bytes, bit_mask);
-    return _mm256_cmpeq_epi8(bytes, _mm256_set1_epi64x(-1));
-}
-    return _mm256_cvtepi32_ps(summed_pairs);
+#if defined(__AVX2__)
+    // AVX2: single-pass byte-level processing, fully unrolled k-loop.
+    // Pipeline: broadcast+shuffle -> AND+cmpeq -> XOR+SUB -> maddubs+madd -> cvt+fma
+    const __m256i ones_8  = _mm256_set1_epi8(1);
+    const __m256i ones_16 = _mm256_set1_epi16(1);
+    const __m256i byte_shuf = _mm256_setr_epi8(
+        0,0,0,0,0,0,0,0, 1,1,1,1,1,1,1,1,
+        2,2,2,2,2,2,2,2, 3,3,3,3,3,3,3,3);
+    const __m256i bit_masks = _mm256_setr_epi8(
+        1,2,4,8,16,32,64,-128, 1,2,4,8,16,32,64,-128,
+        1,2,4,8,16,32,64,-128, 1,2,4,8,16,32,64,-128);
+    const __m256i zero = _mm256_setzero_si256();
+    __m256 acc = _mm256_setzero_ps();
+
+    for (int ib = 0; ib < nb; ++ib) {
+        const float d0 = GGML_CPU_FP16_TO_FP32(x[ib].d);
+        const uint32_t * qs32 = (const uint32_t *)x[ib].qs;
+
+#define Q1_AVX2_BLOCK(K) \
+        { \
+            const __m256i y = _mm256_loadu_si256((const __m256i *)y_ptr[K].qs); \
+            const __m256i sm = _mm256_cmpeq_epi8(_mm256_and_si256( \
+                _mm256_shuffle_epi8(_mm256_set1_epi32((int)qs32[K]), byte_shuf), \
+                bit_masks), zero); \
+            const __m256i sy = _mm256_sub_epi8(_mm256_xor_si256(y, sm), sm); \
+            const __m256i s32 = _mm256_madd_epi16( \
+                _mm256_maddubs_epi16(ones_8, sy), ones_16); \
+            acc_block = (K == 0) \
+                ? _mm256_mul_ps(_mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y_ptr[K].d)), \
+                                _mm256_cvtepi32_ps(s32)) \
+                : _mm256_fmadd_ps(_mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y_ptr[K].d)), \
+                                  _mm256_cvtepi32_ps(s32), acc_block); \
+        }
+
+        const block_q8_0 * y_ptr = &y[ib*4];
+        __m256 acc_block;
+        Q1_AVX2_BLOCK(0)
+        Q1_AVX2_BLOCK(1)
+        Q1_AVX2_BLOCK(2)
+        Q1_AVX2_BLOCK(3)
+#undef Q1_AVX2_BLOCK
+
+        acc = _mm256_fmadd_ps(_mm256_set1_ps(d0), acc_block, acc);
+    }
+    {
+        const __m128 h = _mm_add_ps(_mm256_extractf128_ps(acc, 0),
+                                    _mm256_extractf128_ps(acc, 1));
+        const __m128 q = _mm_add_ps(h, _mm_movehl_ps(h, h));
+        *s = _mm_cvtss_f32(_mm_add_ss(q, _mm_movehdup_ps(q)));
+    }
 #else
     // Perform multiplication and create 16-bit values
     const __m256i dot = _mm256_maddubs_epi16(ax, sy);
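In scalar terms, each 128-weight block of x contributes the signed sum of its paired q8_0 values (a set bit selects +q, a clear bit selects -q; the XOR-with-mask-then-SUB pair in the kernel is a branch-free conditional negate), scaled first by the four q8_0 activation scales and then by the block's own scale. The sketch below is a minimal scalar reference of that reduction, useful as a mental model or test oracle; the struct layouts, the float (rather than fp16) scale fields, and the *_sketch / vec_dot_q1_q8_ref names are assumptions inferred from the kernel above, not taken from the actual Q1_0_g128 headers.

#include <stdint.h>

typedef struct { float d; uint8_t qs[16]; } block_q1_sketch;  // 128 sign bits + scale
typedef struct { float d; int8_t  qs[32]; } block_q8_sketch;  // q8_0-like block

static float vec_dot_q1_q8_ref(int nb, const block_q1_sketch * x, const block_q8_sketch * y) {
    float sum = 0.0f;
    for (int ib = 0; ib < nb; ++ib) {                  // one 128-weight block of x ...
        float block_sum = 0.0f;
        for (int k = 0; k < 4; ++k) {                  // ... pairs with four q8 blocks
            const block_q8_sketch * yk = &y[ib*4 + k];
            int32_t isum = 0;
            for (int j = 0; j < 32; ++j) {
                // bit set -> +1 weight, bit clear -> -1 weight
                const int bit = (x[ib].qs[k*4 + j/8] >> (j % 8)) & 1;
                isum += (bit ? 1 : -1) * yk->qs[j];
            }
            block_sum += yk->d * (float)isum;          // per-subblock activation scale
        }
        sum += x[ib].d * block_sum;                    // per-block weight scale
    }
    return sum;                                        // what the kernel stores into *s
}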
ggml/src/ggml-cpu/ggml-cpu.c  CHANGED

@@ -1185,15 +1185,16 @@ static void ggml_compute_forward_mul_mat_one_chunk(
     assert(ne12 % ne02 == 0);
     assert(ne13 % ne03 == 0);
 
-    // block-tiling
-    const int64_t blck_0 = 16;
+    // COM6-inspired block-tiling: larger blocks for Q1_0_g128 (1-bit weights are tiny,
+    // so we can fit more rows in L1). Prefetch next weight block while processing current.
+    const int64_t blck_0 = (type == GGML_TYPE_Q1_0_g128) ? 64 : 16;
     const int64_t blck_1 = 16;
 
     const size_t src1_col_stride = src1_cont || src1->type != vec_dot_type ? row_size : nb11;
 
     // attempt to reduce false-sharing (does not seem to make a difference)
-    // 16 * 2, accounting for mmla kernels
-    float tmp[32];
+    // Size: blck_0 * 2 (accounting for mmla kernels that compute 2 rows at once)
+    float tmp[128];
 
     for (int64_t iir1 = ir1_start; iir1 < ir1_end; iir1 += blck_1) {
         for (int64_t iir0 = ir0_start; iir0 < ir0_end; iir0 += blck_0) {

@@ -1226,12 +1227,17 @@
                 //     vec_dot(ne00, &dst_col[ir0], src0_row + ir0*nb01, src1_col);
                 //}
 
-                for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir0_end; ir0 += num_rows_per_vec_dot) {
+                // COM6-inspired: prefetch next weight rows while computing current ones.
+                const int64_t ir0_max = MIN(iir0 + blck_0, ir0_end);
+                for (int64_t ir0 = iir0; ir0 < ir0_max; ir0 += num_rows_per_vec_dot) {
+                    if (ir0 + 4 * num_rows_per_vec_dot < ir0_max) {
+                        __builtin_prefetch(src0_row + (ir0 + 4 * num_rows_per_vec_dot) * nb01, 0, 1);
+                    }
                     vec_dot(ne00, &tmp[ir0 - iir0], (num_rows_per_vec_dot > 1 ? 16 : 0), src0_row + ir0 * nb01, (num_rows_per_vec_dot > 1 ? nb01 : 0), src1_col, (num_rows_per_vec_dot > 1 ? src1_col_stride : 0), num_rows_per_vec_dot);
                 }
 
                 for (int cn = 0; cn < num_rows_per_vec_dot; ++cn) {
-                    memcpy(&dst_col[iir0 + cn * nb1 / nb0], tmp + (cn * 16), (MIN(iir0 + blck_0, ir0_end) - iir0) * sizeof(float));
+                    memcpy(&dst_col[iir0 + cn * nb1 / nb0], tmp + (cn * 16), (ir0_max - iir0) * sizeof(float));
                 }
             }
         }
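The blck_0 = 64 choice is easy to sanity-check with a hypothetical 4096-wide row: at 18 bytes per 128 weights (16 bytes of sign bits plus a 2-byte fp16 scale, per the kernel above) a row costs 32 * 18 = 576 bytes, so a 64-row tile is about 36 KiB and still fits a typical 32-48 KiB L1d, whereas a 4-bit format at roughly 0.56 bytes/weight would already need about 144 KiB for the same tile. The prefetch is the usual distance-ahead pattern; below is a generic sketch of the same idea, with placeholder names (rows_times_col, dot_row, PREFETCH_AHEAD are not ggml identifiers) and the GCC/Clang __builtin_prefetch the diff relies on.

#include <stddef.h>

// Same shape as the loop in the diff: while row r is being reduced, request the row
// that will be needed a few iterations later so its loads overlap with the compute.
enum { PREFETCH_AHEAD = 4 };   // the diff looks 4 * num_rows_per_vec_dot rows ahead

static void rows_times_col(const unsigned char * rows, size_t row_bytes, int n_rows,
                           const float * col, float * out,
                           float (*dot_row)(const unsigned char *, const float *)) {
    for (int r = 0; r < n_rows; ++r) {
        if (r + PREFETCH_AHEAD < n_rows) {
            // (addr, rw = 0 -> read, locality = 1 -> low temporal locality), as in the diff
            __builtin_prefetch(rows + (size_t)(r + PREFETCH_AHEAD) * row_bytes, 0, 1);
        }
        out[r] = dot_row(rows + (size_t)r * row_bytes, col);
    }
}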