|
|
|
|
|
|
|
|
|
|
|
#include <cassert> |
|
#include <cstddef> |
|
#include <limits> |
|
|
|
#include <xnnpack.h> |
|
#include <xnnpack/aarch64-assembler.h> |
|
#include <xnnpack/gemm.h> |
|
#include <xnnpack/memory.h> |
|
#include <xnnpack/microparams.h> |
|
#include <xnnpack/post-operation.h> |
|
|
|
namespace xnnpack { |
|
namespace aarch64 { |
|
namespace { |
|
class Generator : public MacroAssembler { |
|
using MacroAssembler::MacroAssembler; |
|
|
|
public: |
|
void generate(size_t max_mr, size_t nc_mod_nr, size_t kc, const jit_gemm_params* jit_gemm_params); |
|
}; |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// Emits an f16 GEMM microkernel that produces up to 6 rows x 16 columns of
// output per column tile. Optional min/max clamping is decided here, at
// generation time.
//
// Generation-time parameters:
//   max_mr          - number of rows to emit code for (asserted <= 6). Code for
//                     rows >= max_mr is not emitted at all.
//   nc_mod_nr       - only checked by an assertion (< 16); it does not change
//                     the emitted instructions.
//   kc              - only checked by assertions (non-zero, a multiple of
//                     sizeof(uint16_t)). The emitted code reads the reduction
//                     length from x2 at run time instead.
//   jit_gemm_params - provides f16_minmax.min/max (fp16 bit patterns) and
//                     num_post_operations. No post-operation code is emitted.
//                     NOTE(review): in release builds (asserts off) a non-zero
//                     num_post_operations is silently ignored — confirm callers
//                     never request post-ops from this generator.
//
// Run-time register usage of the emitted code, as read from the instructions
// below. It presumably follows the XNNPACK f16 GEMM ukernel ABI — confirm.
//   x0        row count (compared with 2/4/6), then reused as the k byte counter
//   x1        remaining output columns (decremented by 16 per tile)
//   x2        bytes of A consumed per row (k-loop bound; used to rewind A rows)
//   x3        A row 0; x4 A row stride, later overwritten with the A row 5 pointer
//   x5        packed weights: 16 bias values, then 16 B values per k step
//   x6        C row 0; x7 C row stride, later overwritten with the C row 5 pointer
//   [sp]      C advance applied after each full 16-column tile (kept in x8)
//   [sp + 8]  params pointer. Its first 32 bits are loaded into s6: lane 0 is
//             the lower bound (FMAX), lane 1 the upper bound (FMIN).
//   x9-x12    A rows 1-4; x16, x17, x14, x13: C rows 1-4
//   v0-v5     A values for rows 0-5; v16-v19 B values
//   v20-v31   accumulators, two 8 x fp16 registers per row (row r: v20+2r, v21+2r)
void Generator::generate(size_t max_mr, size_t nc_mod_nr, size_t kc, const jit_gemm_params* jit_gemm_params)
{
  assert(max_mr <= 6);
  assert(nc_mod_nr < 16);
  assert(kc != 0);
  assert(kc % sizeof(uint16_t) == 0);

  // l0: column-tile start, l1: k main loop, l2: k-loop epilogue,
  // l3: clamp + store, l4: single-fp16 k remainder,
  // l5..l8: partial stores of 8/4/2/1 columns, l9: return.
  Label l0, l1, l2, l3, l4, l5, l6, l7, l8, l9;
  const size_t num_post_operations = jit_gemm_params->num_post_operations;
  (void) num_post_operations;
  const uint16_t min = jit_gemm_params->f16_minmax.min;
  const uint16_t max = jit_gemm_params->f16_minmax.max;
  // 0xFC00 / 0x7C00 are the IEEE binary16 encodings of -inf / +inf. Clamping
  // against an infinite bound is a no-op, so that pass is skipped entirely.
  const bool clamp_min = min != UINT16_C(0xFC00);
  const bool clamp_max = max != UINT16_C(0x7C00);
  assert(num_post_operations == 0 || (!clamp_min && !clamp_max));

  // x8 = params pointer (second stack argument).
  ldr(x8, mem[sp, 8]);

  // Set up the A/C pointers for rows 1..5. Rows at or beyond the run-time row
  // count (x0) alias the previous row. Their loads stay in bounds, and their
  // stores rewrite the same values the aliased row already produced.
  if (max_mr > 1) {
    cmp(x0, 2);
    add(x9, x3, x4);
    add(x16, x6, x7);
    csel(x9, x3, x9, kLO);
    csel(x16, x6, x16, kLO);
  }

  // s6 = the two fp16 clamp bounds (lane 0 lower, lane 1 upper).
  ldr(s6, mem[x8]);

  if (max_mr > 2) {
    add(x10, x9, x4);
    add(x17, x16, x7);

    // Reuses the flags of cmp(x0, 2) above: LS means row count <= 2.
    csel(x10, x9, x10, kLS);
    csel(x17, x16, x17, kLS);
  }

  if (max_mr > 3) {
    cmp(x0, 4);
    add(x11, x10, x4);
    add(x14, x17, x7);
    csel(x11, x10, x11, kLO);
    csel(x14, x17, x14, kLO);
  }

  if (max_mr > 4) {
    add(x12, x11, x4);
    add(x13, x14, x7);

    // Reuses the flags of cmp(x0, 4) above: LS means row count <= 4.
    csel(x12, x11, x12, kLS);
    csel(x13, x14, x13, kLS);
  }

  if (max_mr > 5) {
    // x4 / x7 (the strides) are not needed afterwards, so they become the
    // row-5 A / C pointers.
    cmp(x0, 6);
    add(x4, x12, x4);
    add(x7, x13, x7);
    csel(x4, x12, x4, kLO);
    csel(x7, x13, x7, kLO);
  }

  // x8 = C advance per full 16-column tile (first stack argument).
  ldr(x8, mem[sp]);

  // --- Start of one 16-column output tile. ---
  bind(l0);

  // Initialize every row's accumulators with the 16 bias values from w.
  ldp(q20, q21, mem[x5], 32);
  if (max_mr > 1) {
    mov(v22.v16b(), v20.v16b());
    mov(v23.v16b(), v21.v16b());
  }
  if (max_mr > 2) {
    mov(v24.v16b(), v20.v16b());
    mov(v25.v16b(), v21.v16b());
  }
  if (max_mr > 3) {
    mov(v26.v16b(), v20.v16b());
    mov(v27.v16b(), v21.v16b());
  }
  if (max_mr > 4) {
    mov(v28.v16b(), v20.v16b());
    mov(v29.v16b(), v21.v16b());
  }
  if (max_mr > 5) {
    mov(v30.v16b(), v20.v16b());
    mov(v31.v16b(), v21.v16b());
  }

  // x0 = k bytes remaining minus 4 (one iteration handles 2 fp16 = 4 bytes of A).
  // If fewer than 4 bytes remain, only the single-fp16 remainder path runs.
  subs(x0, x2, 4);
  b_lo(l4);

  // Prologue: load 2 fp16 of A for rows 0-3 and the first 16 B values.
  // Rows 4-5 are loaded at the top of the loop / epilogue instead.
  ldr(s0, mem[x3], 4);
  ldr(q16, mem[x5], 16);
  ldr(q17, mem[x5], 16);
  if (max_mr > 1) {
    ldr(s1, mem[x9], 4);
  }
  if (max_mr > 2) {
    ldr(s2, mem[x10], 4);
  }
  if (max_mr > 3) {
    ldr(s3, mem[x11], 4);
  }

  // If no further full 4-byte chunk follows, go straight to the epilogue.
  subs(x0, x0, 4);
  b_lo(l2);

  // Align the loop head (presumably to an 8-byte boundary).
  align(8);

  // Main loop: two k steps per iteration. Step A (lane h[0]) uses B in v16/v17;
  // step B (lane h[1]) uses B in v18/v19. Loads for the next iteration are
  // interleaved with the FMLAs. Each 128-bit B register is filled as two 64-bit
  // halves (ldr d + ld1 {d}[1]), presumably to pair with FMLAs on the in-order
  // Cortex-A55 named by this kernel — confirm.
  bind(l1);
  fmla(v20.v8h(), v16.v8h(), v0.h()[0]);
  if (max_mr > 4) {
    ldr(s4, mem[x12], 4);
  }
  fmla(v21.v8h(), v17.v8h(), v0.h()[0]);
  if (max_mr > 5) {
    ldr(s5, mem[x4], 4);
  }
  if (max_mr > 1) {
    fmla(v22.v8h(), v16.v8h(), v1.h()[0]);
  }
  ldr(d18, mem[x5], 8);
  if (max_mr > 1) {
    fmla(v23.v8h(), v17.v8h(), v1.h()[0]);
  }
  ld1({v18.d()}, 1, mem[x5], 8);
  if (max_mr > 2) {
    fmla(v24.v8h(), v16.v8h(), v2.h()[0]);
  }
  ldr(d19, mem[x5], 8);
  if (max_mr > 2) {
    fmla(v25.v8h(), v17.v8h(), v2.h()[0]);
  }
  ld1({v19.d()}, 1, mem[x5], 8);
  if (max_mr > 3) {
    fmla(v26.v8h(), v16.v8h(), v3.h()[0]);
    fmla(v27.v8h(), v17.v8h(), v3.h()[0]);
  }
  if (max_mr > 4) {
    fmla(v28.v8h(), v16.v8h(), v4.h()[0]);
    fmla(v29.v8h(), v17.v8h(), v4.h()[0]);
  }
  if (max_mr > 5) {
    fmla(v30.v8h(), v16.v8h(), v5.h()[0]);
    fmla(v31.v8h(), v17.v8h(), v5.h()[0]);
  }
  // Loop-counter update. The FMLAs and loads below leave the flags intact
  // for b_hs.
  subs(x0, x0, 4);

  // Second k step. v16/v17 and s0-s3 are refilled for the next iteration only
  // after their last use in this iteration.
  fmla(v20.v8h(), v18.v8h(), v0.h()[1]);
  ldr(d16, mem[x5], 8);
  fmla(v21.v8h(), v19.v8h(), v0.h()[1]);
  ld1({v16.d()}, 1, mem[x5], 8);
  if (max_mr > 1) {
    fmla(v22.v8h(), v18.v8h(), v1.h()[1]);
  }
  ldr(d17, mem[x5], 8);
  if (max_mr > 1) {
    fmla(v23.v8h(), v19.v8h(), v1.h()[1]);
  }
  ld1({v17.d()}, 1, mem[x5], 8);
  if (max_mr > 2) {
    fmla(v24.v8h(), v18.v8h(), v2.h()[1]);
    fmla(v25.v8h(), v19.v8h(), v2.h()[1]);
  }
  if (max_mr > 3) {
    fmla(v26.v8h(), v18.v8h(), v3.h()[1]);
    fmla(v27.v8h(), v19.v8h(), v3.h()[1]);
  }
  ldr(s0, mem[x3], 4);
  if (max_mr > 4) {
    fmla(v28.v8h(), v18.v8h(), v4.h()[1]);
  }
  if (max_mr > 1) {
    ldr(s1, mem[x9], 4);
  }
  if (max_mr > 4) {
    fmla(v29.v8h(), v19.v8h(), v4.h()[1]);
  }
  if (max_mr > 2) {
    ldr(s2, mem[x10], 4);
  }
  if (max_mr > 5) {
    fmla(v30.v8h(), v18.v8h(), v5.h()[1]);
  }
  if (max_mr > 3) {
    ldr(s3, mem[x11], 4);
  }
  if (max_mr > 5) {
    fmla(v31.v8h(), v19.v8h(), v5.h()[1]);
  }
  b_hs(l1);

  // Epilogue: consume the already-loaded 2 fp16 of A. It loads only the
  // remaining A rows 4-5 and the second B pair, and prefetches nothing.
  bind(l2);
  fmla(v20.v8h(), v16.v8h(), v0.h()[0]);
  if (max_mr > 4) {
    ldr(s4, mem[x12], 4);
  }
  fmla(v21.v8h(), v17.v8h(), v0.h()[0]);
  if (max_mr > 5) {
    ldr(s5, mem[x4], 4);
  }
  if (max_mr > 1) {
    fmla(v22.v8h(), v16.v8h(), v1.h()[0]);
  }
  ldr(d18, mem[x5], 8);
  if (max_mr > 1) {
    fmla(v23.v8h(), v17.v8h(), v1.h()[0]);
  }
  ld1({v18.d()}, 1, mem[x5], 8);
  if (max_mr > 2) {
    fmla(v24.v8h(), v16.v8h(), v2.h()[0]);
  }
  ldr(d19, mem[x5], 8);
  if (max_mr > 2) {
    fmla(v25.v8h(), v17.v8h(), v2.h()[0]);
  }
  ld1({v19.d()}, 1, mem[x5], 8);
  if (max_mr > 3) {
    fmla(v26.v8h(), v16.v8h(), v3.h()[0]);
    fmla(v27.v8h(), v17.v8h(), v3.h()[0]);
  }
  if (max_mr > 4) {
    fmla(v28.v8h(), v16.v8h(), v4.h()[0]);
    fmla(v29.v8h(), v17.v8h(), v4.h()[0]);
  }
  if (max_mr > 5) {
    fmla(v30.v8h(), v16.v8h(), v5.h()[0]);
    fmla(v31.v8h(), v17.v8h(), v5.h()[0]);
  }

  fmla(v20.v8h(), v18.v8h(), v0.h()[1]);
  fmla(v21.v8h(), v19.v8h(), v0.h()[1]);
  if (max_mr > 1) {
    fmla(v22.v8h(), v18.v8h(), v1.h()[1]);
    fmla(v23.v8h(), v19.v8h(), v1.h()[1]);
  }
  if (max_mr > 2) {
    fmla(v24.v8h(), v18.v8h(), v2.h()[1]);
    fmla(v25.v8h(), v19.v8h(), v2.h()[1]);
  }
  if (max_mr > 3) {
    fmla(v26.v8h(), v18.v8h(), v3.h()[1]);
    fmla(v27.v8h(), v19.v8h(), v3.h()[1]);
  }
  if (max_mr > 4) {
    fmla(v28.v8h(), v18.v8h(), v4.h()[1]);
    fmla(v29.v8h(), v19.v8h(), v4.h()[1]);
  }
  if (max_mr > 5) {
    fmla(v30.v8h(), v18.v8h(), v5.h()[1]);
    fmla(v31.v8h(), v19.v8h(), v5.h()[1]);
  }

  // x0 is now negative. Because it was decremented in steps of 4, bit 1 is set
  // exactly when one trailing fp16 (2 bytes) of A is still unprocessed.
  tbnz(x0, 1, l4);

  // Clamp and store.
  bind(l3);

  // Broadcast the bounds. v4/v5 also hold A rows 4/5 during the k loop; they
  // are free here because accumulation is finished.
  // NOTE(review): these DUPs are emitted even when neither clamp is enabled,
  // in which case they are dead but harmless.
  dup(v4.v8h(), v6.h()[0]);
  dup(v5.v8h(), v6.h()[1]);
  if (clamp_min) {
    fmax(v20.v8h(), v20.v8h(), v4.v8h());
    fmax(v21.v8h(), v21.v8h(), v4.v8h());
    if (max_mr > 1) {
      fmax(v22.v8h(), v22.v8h(), v4.v8h());
      fmax(v23.v8h(), v23.v8h(), v4.v8h());
    }
    if (max_mr > 2) {
      fmax(v24.v8h(), v24.v8h(), v4.v8h());
      fmax(v25.v8h(), v25.v8h(), v4.v8h());
    }
    if (max_mr > 3) {
      fmax(v26.v8h(), v26.v8h(), v4.v8h());
      fmax(v27.v8h(), v27.v8h(), v4.v8h());
    }
    if (max_mr > 4) {
      fmax(v28.v8h(), v28.v8h(), v4.v8h());
      fmax(v29.v8h(), v29.v8h(), v4.v8h());
    }
    if (max_mr > 5) {
      fmax(v30.v8h(), v30.v8h(), v4.v8h());
      fmax(v31.v8h(), v31.v8h(), v4.v8h());
    }
  }
  // Consume one 16-column tile. The flags set here drive both b_lo(l5) and
  // b_hi(l0) below, since FMIN, ST1 and plain SUB do not modify flags.
  subs(x1, x1, 16);
  if (clamp_max) {
    fmin(v20.v8h(), v20.v8h(), v5.v8h());
    fmin(v21.v8h(), v21.v8h(), v5.v8h());
    if (max_mr > 1) {
      fmin(v22.v8h(), v22.v8h(), v5.v8h());
      fmin(v23.v8h(), v23.v8h(), v5.v8h());
    }
    if (max_mr > 2) {
      fmin(v24.v8h(), v24.v8h(), v5.v8h());
      fmin(v25.v8h(), v25.v8h(), v5.v8h());
    }
    if (max_mr > 3) {
      fmin(v26.v8h(), v26.v8h(), v5.v8h());
      fmin(v27.v8h(), v27.v8h(), v5.v8h());
    }
    if (max_mr > 4) {
      fmin(v28.v8h(), v28.v8h(), v5.v8h());
      fmin(v29.v8h(), v29.v8h(), v5.v8h());
    }
    if (max_mr > 5) {
      fmin(v30.v8h(), v30.v8h(), v5.v8h());
      fmin(v31.v8h(), v31.v8h(), v5.v8h());
    }
  }

  // Fewer than 16 columns were left: take the partial-store path.
  b_lo(l5);

  // Full 16-column store for each row. C pointers advance by x8, and each A
  // pointer is rewound by the x2 bytes the k loop consumed so the next tile
  // re-reads the same A rows.
  st1({v20.v16b(), v21.v16b()}, mem[x6], x8);
  sub(x3, x3, x2);
  if (max_mr > 1) {
    st1({v22.v16b(), v23.v16b()}, mem[x16], x8);
    sub(x9, x9, x2);
  }
  if (max_mr > 2) {
    st1({v24.v16b(), v25.v16b()}, mem[x17], x8);
    sub(x10, x10, x2);
  }
  if (max_mr > 3) {
    st1({v26.v16b(), v27.v16b()}, mem[x14], x8);
    sub(x11, x11, x2);
  }
  if (max_mr > 4) {
    st1({v28.v16b(), v29.v16b()}, mem[x13], x8);
    sub(x12, x12, x2);
  }
  if (max_mr > 5) {
    st1({v30.v16b(), v31.v16b()}, mem[x7], x8);
    sub(x4, x4, x2);
  }

  // More columns remain: process the next tile.
  b_hi(l0);
  ret();

  // k remainder: one fp16 of A per row (2 bytes) plus 16 B values, then clamp/store.
  bind(l4);

  ldr(h0, mem[x3], 2);
  ldr(q16, mem[x5], 16);
  ldr(q17, mem[x5], 16);
  fmla(v20.v8h(), v16.v8h(), v0.h()[0]);
  if (max_mr > 1) {
    ldr(h1, mem[x9], 2);
    fmla(v22.v8h(), v16.v8h(), v1.h()[0]);
  }
  if (max_mr > 2) {
    ldr(h2, mem[x10], 2);
    fmla(v24.v8h(), v16.v8h(), v2.h()[0]);
  }
  if (max_mr > 3) {
    ldr(h3, mem[x11], 2);
    fmla(v26.v8h(), v16.v8h(), v3.h()[0]);
  }
  if (max_mr > 4) {
    ldr(h4, mem[x12], 2);
    fmla(v28.v8h(), v16.v8h(), v4.h()[0]);
  }
  if (max_mr > 5) {
    ldr(h5, mem[x4], 2);
    fmla(v30.v8h(), v16.v8h(), v5.h()[0]);
  }
  fmla(v21.v8h(), v17.v8h(), v0.h()[0]);
  if (max_mr > 1) {
    fmla(v23.v8h(), v17.v8h(), v1.h()[0]);
  }
  if (max_mr > 2) {
    fmla(v25.v8h(), v17.v8h(), v2.h()[0]);
  }
  if (max_mr > 3) {
    fmla(v27.v8h(), v17.v8h(), v3.h()[0]);
  }
  if (max_mr > 4) {
    fmla(v29.v8h(), v17.v8h(), v4.h()[0]);
  }
  if (max_mr > 5) {
    fmla(v31.v8h(), v17.v8h(), v5.h()[0]);
  }
  b(l3);

  // Partial store of the last tile (x1 has 1..15 columns left in its low bits).
  // Each stage stores a chunk and shifts the not-yet-stored columns down.
  // Bit 3: store 8 columns, then move the upper 8 into the low register.
  bind(l5);
  tbz(x1, 3, l6);
  str(q20, mem[x6], 16);
  mov(v20.v16b(), v21.v16b());
  if (max_mr > 1) {
    str(q22, mem[x16], 16);
    mov(v22.v16b(), v23.v16b());
  }
  if (max_mr > 2) {
    str(q24, mem[x17], 16);
    mov(v24.v16b(), v25.v16b());
  }
  if (max_mr > 3) {
    str(q26, mem[x14], 16);
    mov(v26.v16b(), v27.v16b());
  }
  if (max_mr > 4) {
    str(q28, mem[x13], 16);
    mov(v28.v16b(), v29.v16b());
  }
  if (max_mr > 5) {
    str(q30, mem[x7], 16);
    mov(v30.v16b(), v31.v16b());
  }

  // Bit 2: store 4 columns (64 bits), then move the high 64 bits down.
  bind(l6);
  tbz(x1, 2, l7);
  str(d20, mem[x6], 8);
  if (max_mr > 1) {
    str(d22, mem[x16], 8);
  }
  dup(d20, v20.d()[1]);
  if (max_mr > 1) {
    dup(d22, v22.d()[1]);
  }
  if (max_mr > 2) {
    str(d24, mem[x17], 8);
  }
  if (max_mr > 3) {
    str(d26, mem[x14], 8);
  }
  if (max_mr > 2) {
    dup(d24, v24.d()[1]);
  }
  if (max_mr > 3) {
    dup(d26, v26.d()[1]);
  }
  if (max_mr > 4) {
    str(d28, mem[x13], 8);
  }
  if (max_mr > 5) {
    str(d30, mem[x7], 8);
  }
  if (max_mr > 4) {
    dup(d28, v28.d()[1]);
  }
  if (max_mr > 5) {
    dup(d30, v30.d()[1]);
  }

  // Bit 1: store 2 columns (32 bits), then move the next 32 bits down.
  bind(l7);
  tbz(x1, 1, l8);
  str(s20, mem[x6], 4);
  if (max_mr > 1) {
    str(s22, mem[x16], 4);
  }
  dup(s20, v20.s()[1]);
  if (max_mr > 1) {
    dup(s22, v22.s()[1]);
  }
  if (max_mr > 2) {
    str(s24, mem[x17], 4);
  }
  if (max_mr > 3) {
    str(s26, mem[x14], 4);
  }
  if (max_mr > 2) {
    dup(s24, v24.s()[1]);
  }
  if (max_mr > 3) {
    dup(s26, v26.s()[1]);
  }
  if (max_mr > 4) {
    str(s28, mem[x13], 4);
  }
  if (max_mr > 5) {
    str(s30, mem[x7], 4);
  }
  if (max_mr > 4) {
    dup(s28, v28.s()[1]);
  }
  if (max_mr > 5) {
    dup(s30, v30.s()[1]);
  }

  // Bit 0: store the final single fp16 column.
  bind(l8);
  tbz(x1, 0, l9);
  str(h20, mem[x6]);
  if (max_mr > 1) {
    str(h22, mem[x16]);
  }
  if (max_mr > 2) {
    str(h24, mem[x17]);
  }
  if (max_mr > 3) {
    str(h26, mem[x14]);
  }
  if (max_mr > 4) {
    str(h28, mem[x13]);
  }
  if (max_mr > 5) {
    str(h30, mem[x7]);
  }
  bind(l9);
  ret();

  // Pad the end of the generated code to a 16-byte boundary using HLT
  // instructions as filler.
  align(16, AlignInstruction::kHlt);
}
|
} |
|
} |
|
} |
|
|
|
// Public entry point: JIT-compiles the f16 6x16 Cortex-A55 GEMM microkernel
// into `code`.
//
// `params` must point to a jit_gemm_params (asserted non-null). It selects the
// clamp bounds baked into the code. max_mr, nc_mod_nr and kc are forwarded to
// Generator::generate unchanged.
// Returns xnn_status_success when the assembler reports no error after
// finalization, and xnn_status_invalid_state otherwise.
xnn_status_t xnn_generate_f16_gemm_ukernel_6x16__aarch64_neonfp16arith_cortex_a55(xnn_code_buffer* code, size_t max_mr, size_t nc_mod_nr, size_t kc, const void* params) {
  using namespace xnnpack::aarch64;

  Generator emitter(code);
  assert(params != nullptr);
  const jit_gemm_params* const gemm_params = static_cast<const jit_gemm_params*>(params);
  emitter.generate(max_mr, nc_mod_nr, kc, gemm_params);
  emitter.finalize();

  const bool emitted_cleanly = emitter.error() == xnnpack::Error::kNoError;
  return emitted_cleanly ? xnn_status_success : xnn_status_invalid_state;
}
|
|