File size: 2,243 Bytes
57e3690
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
#version 450

#include "common.comp"

#include "op_mul_mv_q_n_pre.comp"

#define SIZE_OF_D 2

#define N_DST 4 // each SIMD group works on 4 rows
#define N_SIMDGROUP 2 // number of SIMD groups in a thread group
#define N_SIMDWIDTH 32 // assuming SIMD group size is 32

#define NB_Q8_0 8

// Matrix-vector multiply of a Q8_0-quantized matrix (inA, addressed as raw
// bytes) with the vectors in inB; one float result per output row is written
// to out_.
//
// Work decomposition, as read from the code below:
//   - gl_WorkGroupID.x (r0) selects a group of nsg*nr rows, gl_WorkGroupID.y
//     (r1) selects the inB vector at element offset r1*ne10, and
//     gl_WorkGroupID.z (im) is the batch index, split into i12/i13 via ne12.
//   - Each subgroup owns nr (= N_DST) consecutive rows starting at first_row.
//   - Within a subgroup, 4 lanes share one Q8_0 block: ix = lane/4 selects the
//     block, il = lane%4 selects which NB_Q8_0-quant slice of it the lane uses.
//   - Per-lane partial sums are reduced with subgroupAdd and a single elected
//     lane stores each row's result.
//
// Q8_0 block layout as accessed here: SIZE_OF_D bytes holding a float16 scale
// `d`, followed by int8 quants; consecutive blocks are sizeof_block_q8_0 bytes
// apart.
void main() {
    // NB: hack to make compatible with AMD GPUs that have a subgroup size of 64
    if (gl_SubgroupInvocationID > 31)
        return;

    const int nr  = N_DST;
    const int nsg = N_SIMDGROUP;
    const int nw  = N_SIMDWIDTH;

    const int nb = pcs.ne00/QK8_0; // number of Q8_0 blocks per row of inA
    const uint r0 = gl_WorkGroupID.x;
    const uint r1 = gl_WorkGroupID.y;
    const uint im = gl_WorkGroupID.z;

    const uint first_row = (r0 * nsg + gl_SubgroupID) * nr;

    const uint i12 = im%pcs.ne12;
    const uint i13 = im/pcs.ne12;

    // Block index of first_row inside inA. Dividing i12/i13 by pcs.r2/pcs.r3
    // lets several batch indices map onto the same inA matrix (presumably
    // broadcast ratios set by the host dispatch — TODO confirm).
    const uint offset0 = first_row * nb + (i12/pcs.r2)*(nb*pcs.ne01) + (i13/pcs.r3)*(nb*pcs.ne01*pcs.ne02);

    const uint x = offset0*sizeof_block_q8_0 + pcs.inAOff; // byte offset, based from inA
    const uint y = r1*pcs.ne10 + im*pcs.ne00*pcs.ne1 + pcs.inBOff; // element offset, based from inB

    float yl[NB_Q8_0];
    float sumf[N_DST]={0.f, 0.f, 0.f, 0.f};

    const uint ix = gl_SubgroupInvocationID.x/4;
    const uint il = gl_SubgroupInvocationID.x%4;

    uint yb = y + ix * QK8_0 + NB_Q8_0*il;

    // each thread in a SIMD group deals with NB_Q8_0 quants at a time
    for (uint ib = ix; ib < nb; ib += nw/4) {
        for (int i = 0; i < NB_Q8_0; ++i) {
            yl[i] = inB[yb + i];
        }

        for (int row = 0; row < nr; row++) {
            // The last row group can extend past ne01 when ne01 is not a
            // multiple of nsg*nr. Those rows are never stored below, so stop
            // before reading their blocks; otherwise the final matrix would be
            // read past the end of inA. first_row is identical for every lane
            // of the subgroup, so this exit is subgroup-uniform.
            if (first_row + row >= pcs.ne01) {
                break;
            }

            const uint block_offset = (ib+row*nb) * sizeof_block_q8_0;
            float sumq = 0.f;
            for (int iq = 0; iq < NB_Q8_0; ++iq) {
                // Quants start right after the SIZE_OF_D-byte scale; each lane
                // reads its own NB_Q8_0-wide slice selected by il.
                const int8_t qs_iq = int8_t(inA[x + block_offset + SIZE_OF_D + NB_Q8_0*il + iq]);
                sumq += qs_iq * yl[iq];
            }
            const float16_t d = u8BufToFloat16(inA, x + block_offset);
            sumf[row] += sumq*d;
        }

        // Step inB forward by the nw/4 blocks just consumed by this subgroup.
        // Assumes QK8_0 == 4*NB_Q8_0 (4 lanes x NB_Q8_0 quants cover one block).
        yb += NB_Q8_0 * nw;
    }

    for (int row = 0; row < nr; ++row) {
        const float tot = subgroupAdd(sumf[row]);
        if (subgroupElect() && first_row + row < pcs.ne01) {
            out_[r1*pcs.ne0 + im*pcs.ne0*pcs.ne1 + first_row + row] = tot;
        }
    }
}