File size: 1,298 Bytes
57e3690
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
#version 450

#include "common.comp"

#extension GL_KHR_shader_subgroup_arithmetic : require
#extension GL_EXT_debug_printf : enable

// The workgroup x-size is taken from specialization constant 0. The host is
// expected to set it to the device subgroup size (presumably so that each
// workgroup is exactly one subgroup - confirm against the host code).
layout (local_size_x_id = 0) in;

// Storage buffers, all viewed as flat float arrays:
//   binding 0 - first input operand (read-only)
//   binding 1 - second input operand (read-only)
//   binding 2 - result (write-only)
layout(binding = 0) readonly buffer tensorInA { float inA[]; };
layout(binding = 1) readonly buffer tensorInB { float inB[]; };
layout(binding = 2) writeonly buffer tensorOut { float out_[]; };

// Per-dispatch parameters supplied by the host as push constants.
// The member order and types form the host-side layout; do not reorder.
// NOTE(review): the ne*/nb* names look like element counts / byte strides per
// dimension (the nb* values are divided by 4 before indexing floats) - confirm
// against the host code.
layout(push_constant) uniform parameter {
  uint inAOff;  // added to every inA index, in float elements
  uint inBOff;  // added to every inB index, in float elements
  uint outOff;  // added to every out_ index, in float elements
  int ne00;     // number of element pairs accumulated per output value
  int ne01;     // not read by main() in this shader
  int ne02;     // inA batch count used for the broadcast ratio
  int ne11;     // not read by main() in this shader
  int ne12;     // inB batch count used for the broadcast ratio
  uint nb01;    // inA stride per gl_WorkGroupID.x step (divided by 4 when used)
  uint nb02;    // inA stride per (broadcast-adjusted) batch step (divided by 4 when used)
  uint nb11;    // inB stride per gl_WorkGroupID.y step (divided by 4 when used)
  uint nb12;    // inB stride per (broadcast-adjusted) batch step (divided by 4 when used)
  uint nb1;     // out_ stride per gl_WorkGroupID.y step (divided by 4 when used)
  uint nb2;     // out_ stride per gl_WorkGroupID.z step (divided by 4 when used)
}
pcs;


// Each workgroup produces one value of out_: the dot product of pcs.ne00
// floats read from inA and inB.
//
// gl_WorkGroupID.x selects the inA row, .y the inB row and .z the batch.
// The invocations of a subgroup split the ne00 products between them with a
// stride of gl_SubgroupSize, the partial sums are combined with subgroupAdd,
// and a single elected invocation stores the result.
void main() {
    const uint row_a = gl_WorkGroupID.x;
    const uint row_b = gl_WorkGroupID.y;
    const uint batch = gl_WorkGroupID.z;

    // Batch broadcasting: when one operand has fewer batches than the other,
    // several consecutive z indices map onto the same batch of the smaller one.
    uint batch_a = batch;
    if (pcs.ne12 > pcs.ne02) {
        batch_a = batch / (pcs.ne12 / pcs.ne02);
    }
    uint batch_b = batch;
    if (pcs.ne02 > pcs.ne12) {
        batch_b = batch / (pcs.ne02 / pcs.ne12);
    }

    // Strides are divided by 4 to turn them into float indices; the offsets
    // are added afterwards and are therefore already in float elements.
    const uint base_a = (row_a*pcs.nb01 + batch_a*pcs.nb02) / 4 + pcs.inAOff;
    const uint base_b = (row_b*pcs.nb11 + batch_b*pcs.nb12) / 4 + pcs.inBOff;

    // Per-invocation partial sum over every gl_SubgroupSize-th element.
    float partial = 0.0f;
    for (uint k = gl_SubgroupInvocationID; k < pcs.ne00; k += gl_SubgroupSize) {
        partial += inA[base_a + k] * inB[base_b + k];
    }

    // Reduce across the subgroup; every invocation receives the total.
    const float total = subgroupAdd(partial);

    // Only one invocation per subgroup performs the store.
    if (subgroupElect()) {
        const uint dst = batch*(pcs.nb2/4) + row_b*(pcs.nb1/4) + row_a + pcs.outOff;
        out_[dst] = total;
    }
}