// Copyright 2019 Google LLC // // This source code is licensed under the BSD-style license found in the // LICENSE file in the root directory of this source tree. #include #include #include #include #include #include #include #include #include namespace xnnpack { namespace aarch64 { namespace { class Generator : public MacroAssembler { using MacroAssembler::MacroAssembler; public: void generate(size_t max_mr, size_t nc_mod_nr, size_t kc, size_t ks, const jit_gemm_params* jit_gemm_params); }; // void xnn_f16_igemm_minmax_ukernel_6x16__asm_aarch64_neonfp16arith_cortex_a55r0( // size_t mr, x0 // size_t nc, x1 // size_t kc, x2 / x0 // size_t ks, x3 / x9 // const void** restrict a, x4 // const void* restrict w, x5 // uint8_t* restrict c, x6 // size_t cm_stride, x7 // size_t cn_stride, [sp] -> (x0) // size_t a_offset, [sp + 8] -> x11 // const void* zero, [sp + 16] -> x12 // const xnn_f16_minmax_params params [sp + 24] -> (x8) // d8-d15, x19-x30 need to be preserved if used. x18 is reserved by the OS. // Register usage // A0 x14 v0 v3 // A1 x15 v0[1] v3[1] // A2 x20 v1 v4 // A3 x21 v1[1] v4[1] // A4 x22 v2 v5 // A5 x23 v2[1] v5[1] // B x5 v12 v13 v14 v15 second set of B // B v16 v17 v18 v19 first set // C0 x6 v20 v21 // C1 x16 v22 v23 // C2 x17 v24 v25 // C3 x10 v26 v27 // C4 x13 v28 v29 // C5 x7 v30 v31 // clamp v6, (v4), (v5) // unused v7 v8 v9 v10 v11 // temporary vector shadow register x8 // Converted from: src/f16-igemm/f16-igemm-6x16-minmax-asm-aarch64-neonfp16arith-cortex-a55r0.S void Generator::generate(size_t max_mr, size_t nc_mod_nr, size_t kc, size_t ks, const jit_gemm_params* jit_gemm_params) { assert(max_mr <= 6); assert(nc_mod_nr < 16); assert(kc != 0); assert(kc % sizeof(uint16_t) == 0); assert(ks != 0); Label l0, l1, l2, l3, l4, l5, l6, l7, l8, l9, l10, l11; const size_t num_post_operations = jit_gemm_params->num_post_operations; (void) num_post_operations; // Silence unused warning. const uint16_t min = jit_gemm_params->f16_minmax.min; const uint16_t max = jit_gemm_params->f16_minmax.max; const bool clamp_min = min != UINT16_C(0xFC00); // -Inf. const bool clamp_max = max != UINT16_C(0x7C00); // Inf. assert(num_post_operations == 0 || (!clamp_min && !clamp_max)); // Load zero, params pointer ldp(x12, x8, mem[sp, 16]); // Clamp C pointers if (max_mr > 1) { cmp(x0, 2); // if mr < 2 add(x16, x6, x7); // c1 = c0 + cm_stride csel(x16, x6, x16, kLO); // c1 = c0 } if (max_mr > 2) { add(x17, x16, x7); // c2 = c1 + cm_stride // if mr <= 2 csel(x17, x16, x17, kLS); // c2 = c1 } // Load params ldr(s6, mem[x8]); if (max_mr > 3) { cmp(x0, 4); // if mr < 4 add(x10, x17, x7); // c3 = c2 + cm_stride csel(x10, x17, x10, kLO); // c3 = c2 } if (max_mr > 4) { add(x13, x10, x7); // c4 = c3 + cm_stride // if mr <= 4 csel(x13, x10, x13, kLS); // c4 = c3 } if (max_mr > 5) { cmp(x0, 6); // if mr < 6 add(x7, x13, x7); // c5 = c4 + cm_stride csel(x7, x13, x7, kLO); // c5 = c4 } // Load a_offset ldr(x11, mem[sp, 8]); // Save x20-x23, d12-d15 on stack stp(d12, d13, mem[sp, -64]++); stp(d14, d15, mem[sp, 16]); stp(x20, x21, mem[sp, 32]); stp(x22, x23, mem[sp, 48]); bind(l0); // Load initial bias from w into accumulators ldp(q20, q21, mem[x5], 32); mov(x9, x3); // p = ks if (max_mr > 1) { mov(v22.v16b(), v20.v16b()); mov(v23.v16b(), v21.v16b()); } if (max_mr > 2) { mov(v24.v16b(), v20.v16b()); mov(v25.v16b(), v21.v16b()); } if (max_mr > 3) { mov(v26.v16b(), v20.v16b()); mov(v27.v16b(), v21.v16b()); } if (max_mr > 4) { mov(v28.v16b(), v20.v16b()); mov(v29.v16b(), v21.v16b()); } if (max_mr > 5) { mov(v30.v16b(), v20.v16b()); mov(v31.v16b(), v21.v16b()); } bind(l1); // Load next 6 A pointers if (max_mr == 1) { ldr(x14, mem[x4], 8); } if (max_mr > 1) { ldp(x14, x15, mem[x4], 16); } if (max_mr == 3) { ldr(x20, mem[x4], 8); } if (max_mr > 3) { ldp(x20, x21, mem[x4], 16); } if (max_mr == 5) { ldr(x22, mem[x4], 8); } if (max_mr > 5) { ldp(x22, x23, mem[x4], 16); } cmp(x14, x12); // if a0 == zero add(x14, x14, x11); // a0 += a_offset csel(x14, x12, x14, kEQ); // a0 = zero, else += a0 + a_offset if (max_mr > 1) { cmp(x15, x12); // if a1 == zero add(x15, x15, x11); // a1 += a_offset csel(x15, x12, x15, kEQ); // a1 = zero, else += a1 + a_offset } if (max_mr > 2) { cmp(x20, x12); // if a2 == zero add(x20, x20, x11); // a2 += a_offset csel(x20, x12, x20, kEQ); // a2 = zero, else += a2 + a_offset } if (max_mr > 3) { cmp(x21, x12); // if a3 == zero add(x21, x21, x11); // a3 += a_offset csel(x21, x12, x21, kEQ); // a3 = zero, else += a3 + a_offset } if (max_mr > 4) { cmp(x22, x12); // if a4 == zero add(x22, x22, x11); // a4 += a_offset csel(x22, x12, x22, kEQ); // a4 = zero, else += a4 + a_offset } if (max_mr > 5) { cmp(x23, x12); // if a5 == zero add(x23, x23, x11); // a5 += a_offset csel(x23, x12, x23, kEQ); // a5 = zero, else += a5 + a_offset } // Is there at least 4 halffloats (8 bytes) for prologue + epilogue? subs(x0, x2, 8); // k = kc - 8 b_lo(l5); // Prologue - First group loads, no FMA ldr(s0, mem[x14], 4); // A0 ldp(q16, q17, mem[x5], 32); // B if (max_mr > 2) { ldr(s1, mem[x20], 4); // A2 } if (max_mr > 4) { ldr(s2, mem[x22], 4); // A4 } if (max_mr > 1) { ld1({v0.s()}, 2, mem[x15], 4); // A1 } if (max_mr > 3) { ld1({v1.s()}, 2, mem[x21], 4); // A3 } if (max_mr > 5) { ld1({v2.s()}, 2, mem[x23], 4); // A5 } ldr(q18, mem[x5], 16); ldr(d19, mem[x5], 8); ldr(x8, mem[x5], 8); // ins is in BLOCK 0 subs(x0, x0, 8); // Is there at least 4 halffloats (8 bytes) for main loop? b_lo(l3); align(8); // Main loop - 4 halffloats of A (8 bytes) // 48 FMA + 12 LD32 A + 8 LDR B bind(l2); // First group of 24 FMA, Second group loads // BLOCK 0 ldr(s3, mem[x14], 4); // A0 ins(v19.d()[1], x8); // B from second group fmla(v20.v8h(), v16.v8h(), v0.h()[0]); if (max_mr > 1) { ldr(w8, mem[x15], 4); // A1 fmla(v22.v8h(), v16.v8h(), v0.h()[4]); } if (max_mr > 2) { fmla(v24.v8h(), v16.v8h(), v1.h()[0]); } // BLOCK 1 ldr(d12, mem[x5]); if (max_mr > 1) { ins(v3.d()[1], x8); // A1 ins } if (max_mr > 3) { fmla(v26.v8h(), v16.v8h(), v1.h()[4]); } ldr(x8, mem[x5, 8]); // B if (max_mr > 4) { fmla(v28.v8h(), v16.v8h(), v2.h()[0]); } if (max_mr > 5) { fmla(v30.v8h(), v16.v8h(), v2.h()[4]); } // BLOCK 2 if (max_mr > 2) { ldr(s4, mem[x20], 4); // A2 } ins(v12.d()[1], x8); // B ins fmla(v21.v8h(), v17.v8h(), v0.h()[0]); if (max_mr > 3) { ldr(w8, mem[x21], 4); // A3 } if (max_mr > 1) { fmla(v23.v8h(), v17.v8h(), v0.h()[4]); } if (max_mr > 2) { fmla(v25.v8h(), v17.v8h(), v1.h()[0]); } // BLOCK 3 if (max_mr > 4) { ldr(s5, mem[x22], 4); // A4 } if (max_mr > 3) { ins(v4.d()[1], x8); // A3 ins fmla(v27.v8h(), v17.v8h(), v1.h()[4]); } if (max_mr > 5) { ldr(w8, mem[x23], 4); // A5 } if (max_mr > 4) { fmla(v29.v8h(), v17.v8h(), v2.h()[0]); } if (max_mr > 5) { fmla(v31.v8h(), v17.v8h(), v2.h()[4]); } // BLOCK 4 ldr(d13, mem[x5, 16]); if (max_mr > 5) { ins(v5.d()[1], x8); // A5 ins } fmla(v20.v8h(), v18.v8h(), v0.h()[1]); ldr(x8, mem[x5, 24]); if (max_mr > 1) { fmla(v22.v8h(), v18.v8h(), v0.h()[5]); } if (max_mr > 2) { fmla(v24.v8h(), v18.v8h(), v1.h()[1]); } // BLOCK 5 ldr(d14, mem[x5, 32]); ins(v13.d()[1], x8); // B if (max_mr > 3) { fmla(v26.v8h(), v18.v8h(), v1.h()[5]); } ldr(x8, mem[x5, 40]); if (max_mr > 4) { fmla(v28.v8h(), v18.v8h(), v2.h()[1]); } if (max_mr > 5) { fmla(v30.v8h(), v18.v8h(), v2.h()[5]); } // BLOCK 6 ldr(d15, mem[x5, 48]); ins(v14.d()[1], x8); // B fmla(v21.v8h(), v19.v8h(), v0.h()[1]); ldr(x8, mem[x5, 56]); if (max_mr > 1) { fmla(v23.v8h(), v19.v8h(), v0.h()[5]); } if (max_mr > 2) { fmla(v25.v8h(), v19.v8h(), v1.h()[1]); } // BLOCK 7 ins(v15.d()[1], x8); if (max_mr > 3) { fmla(v27.v8h(), v19.v8h(), v1.h()[5]); } if (max_mr > 4) { fmla(v29.v8h(), v19.v8h(), v2.h()[1]); } if (max_mr > 5) { fmla(v31.v8h(), v19.v8h(), v2.h()[5]); } // Second group of 24 FMA, First group of loads // BLOCK 0 ldr(s0, mem[x14], 4); // A0 fmla(v20.v8h(), v12.v8h(), v3.h()[0]); if (max_mr > 1) { ldr(w8, mem[x15], 4); // A1 fmla(v22.v8h(), v12.v8h(), v3.h()[4]); } if (max_mr > 2) { fmla(v24.v8h(), v12.v8h(), v4.h()[0]); } // BLOCK 1 ldr(d16, mem[x5, 64]); if (max_mr > 1) { ins(v0.d()[1], x8); // A1 ins } if (max_mr > 3) { fmla(v26.v8h(), v12.v8h(), v4.h()[4]); } ldr(x8, mem[x5, 72]); // B if (max_mr > 4) { fmla(v28.v8h(), v12.v8h(), v5.h()[0]); } if (max_mr > 5) { fmla(v30.v8h(), v12.v8h(), v5.h()[4]); } // BLOCK 2 if (max_mr > 2) { ldr(s1, mem[x20], 4); // A2 } ins(v16.d()[1], x8); // B fmla(v21.v8h(), v13.v8h(), v3.h()[0]); if (max_mr > 3) { ldr(w8, mem[x21], 4); // A3 } if (max_mr > 1) { fmla(v23.v8h(), v13.v8h(), v3.h()[4]); } if (max_mr > 2) { fmla(v25.v8h(), v13.v8h(), v4.h()[0]); } // BLOCK 3 if (max_mr > 4) { ldr(s2, mem[x22], 4); // A4 } if (max_mr > 3) { ins(v1.d()[1], x8); // A3 ins fmla(v27.v8h(), v13.v8h(), v4.h()[4]); } if (max_mr > 5) { ldr(w8, mem[x23], 4); // A5 } if (max_mr > 4) { fmla(v29.v8h(), v13.v8h(), v5.h()[0]); } if (max_mr > 5) { fmla(v31.v8h(), v13.v8h(), v5.h()[4]); } // BLOCK 4 ldr(d17, mem[x5, 80]); if (max_mr > 5) { ins(v2.d()[1], x8); // A5 ins } fmla(v20.v8h(), v14.v8h(), v3.h()[1]); ldr(x8, mem[x5, 88]); if (max_mr > 1) { fmla(v22.v8h(), v14.v8h(), v3.h()[5]); } if (max_mr > 2) { fmla(v24.v8h(), v14.v8h(), v4.h()[1]); } // BLOCK 5 ldr(d18, mem[x5, 96]); ins(v17.d()[1], x8); // B if (max_mr > 3) { fmla(v26.v8h(), v14.v8h(), v4.h()[5]); } ldr(x8, mem[x5, 104]); if (max_mr > 4) { fmla(v28.v8h(), v14.v8h(), v5.h()[1]); } if (max_mr > 5) { fmla(v30.v8h(), v14.v8h(), v5.h()[5]); } // BLOCK 6 ldr(d19, mem[x5, 112]); ins(v18.d()[1], x8); // B fmla(v21.v8h(), v15.v8h(), v3.h()[1]); ldr(x8, mem[x5, 120]); if (max_mr > 1) { fmla(v23.v8h(), v15.v8h(), v3.h()[5]); } if (max_mr > 2) { fmla(v25.v8h(), v15.v8h(), v4.h()[1]); } // BLOCK 7 subs(x0, x0, 8); // LDR lands here if (max_mr > 3) { fmla(v27.v8h(), v15.v8h(), v4.h()[5]); } if (max_mr > 4) { fmla(v29.v8h(), v15.v8h(), v5.h()[1]); } add(x5, x5, 128); if (max_mr > 5) { fmla(v31.v8h(), v15.v8h(), v5.h()[5]); } b_hs(l2); // Epilogue - 4 halffloats of A (8 bytes) // 48 FMA + 12 LD32 A + 8 LDR B bind(l3); // First group of 24 FMA, Second group loads // BLOCK 0 ldr(s3, mem[x14], 4); // A0 ins(v19.d()[1], x8); // B from second group fmla(v20.v8h(), v16.v8h(), v0.h()[0]); if (max_mr > 1) { ldr(w8, mem[x15], 4); // A1 fmla(v22.v8h(), v16.v8h(), v0.h()[4]); } if (max_mr > 2) { fmla(v24.v8h(), v16.v8h(), v1.h()[0]); } // BLOCK 1 ldr(d12, mem[x5]); if (max_mr > 1) { ins(v3.d()[1], x8); // A1 ins } if (max_mr > 3) { fmla(v26.v8h(), v16.v8h(), v1.h()[4]); } ldr(x8, mem[x5, 8]); // B if (max_mr > 4) { fmla(v28.v8h(), v16.v8h(), v2.h()[0]); } if (max_mr > 5) { fmla(v30.v8h(), v16.v8h(), v2.h()[4]); } // BLOCK 2 if (max_mr > 2) { ldr(s4, mem[x20], 4); // A2 } ins(v12.d()[1], x8); // B ins fmla(v21.v8h(), v17.v8h(), v0.h()[0]); if (max_mr > 3) { ldr(w8, mem[x21], 4); // A3 } if (max_mr > 1) { fmla(v23.v8h(), v17.v8h(), v0.h()[4]); } if (max_mr > 2) { fmla(v25.v8h(), v17.v8h(), v1.h()[0]); } // BLOCK 3 if (max_mr > 4) { ldr(s5, mem[x22], 4); // A4 } if (max_mr > 3) { ins(v4.d()[1], x8); // A3 ins fmla(v27.v8h(), v17.v8h(), v1.h()[4]); } if (max_mr > 5) { ldr(w8, mem[x23], 4); // A5 } if (max_mr > 4) { fmla(v29.v8h(), v17.v8h(), v2.h()[0]); } if (max_mr > 5) { fmla(v31.v8h(), v17.v8h(), v2.h()[4]); } // BLOCK 4 ldr(d13, mem[x5, 16]); if (max_mr > 5) { ins(v5.d()[1], x8); // A5 ins } fmla(v20.v8h(), v18.v8h(), v0.h()[1]); ldr(x8, mem[x5, 24]); if (max_mr > 1) { fmla(v22.v8h(), v18.v8h(), v0.h()[5]); } if (max_mr > 2) { fmla(v24.v8h(), v18.v8h(), v1.h()[1]); } // BLOCK 5 ldr(d14, mem[x5, 32]); ins(v13.d()[1], x8); // B if (max_mr > 3) { fmla(v26.v8h(), v18.v8h(), v1.h()[5]); } ldr(x8, mem[x5, 40]); if (max_mr > 4) { fmla(v28.v8h(), v18.v8h(), v2.h()[1]); } if (max_mr > 5) { fmla(v30.v8h(), v18.v8h(), v2.h()[5]); } // BLOCK 6 ldr(d15, mem[x5, 48]); ins(v14.d()[1], x8); // B fmla(v21.v8h(), v19.v8h(), v0.h()[1]); ldr(x8, mem[x5, 56]); if (max_mr > 1) { fmla(v23.v8h(), v19.v8h(), v0.h()[5]); } if (max_mr > 2) { fmla(v25.v8h(), v19.v8h(), v1.h()[1]); } // BLOCK 7 ins(v15.d()[1], x8); // B if (max_mr > 3) { fmla(v27.v8h(), v19.v8h(), v1.h()[5]); } if (max_mr > 4) { fmla(v29.v8h(), v19.v8h(), v2.h()[1]); } if (max_mr > 5) { fmla(v31.v8h(), v19.v8h(), v2.h()[5]); } // Second group of 24 FMA, First group of loads // BLOCK 0 fmla(v20.v8h(), v12.v8h(), v3.h()[0]); if (max_mr > 1) { fmla(v22.v8h(), v12.v8h(), v3.h()[4]); } if (max_mr > 2) { fmla(v24.v8h(), v12.v8h(), v4.h()[0]); } // BLOCK 1 if (max_mr > 3) { fmla(v26.v8h(), v12.v8h(), v4.h()[4]); } if (max_mr > 4) { fmla(v28.v8h(), v12.v8h(), v5.h()[0]); } if (max_mr > 5) { fmla(v30.v8h(), v12.v8h(), v5.h()[4]); } // BLOCK 2 fmla(v21.v8h(), v13.v8h(), v3.h()[0]); if (max_mr > 1) { fmla(v23.v8h(), v13.v8h(), v3.h()[4]); } if (max_mr > 2) { fmla(v25.v8h(), v13.v8h(), v4.h()[0]); } // BLOCK 3 if (max_mr > 3) { fmla(v27.v8h(), v13.v8h(), v4.h()[4]); } if (max_mr > 4) { fmla(v29.v8h(), v13.v8h(), v5.h()[0]); } if (max_mr > 5) { fmla(v31.v8h(), v13.v8h(), v5.h()[4]); } // BLOCK 4 fmla(v20.v8h(), v14.v8h(), v3.h()[1]); if (max_mr > 1) { fmla(v22.v8h(), v14.v8h(), v3.h()[5]); } if (max_mr > 2) { fmla(v24.v8h(), v14.v8h(), v4.h()[1]); } // BLOCK 5 if (max_mr > 3) { fmla(v26.v8h(), v14.v8h(), v4.h()[5]); } if (max_mr > 4) { fmla(v28.v8h(), v14.v8h(), v5.h()[1]); } if (max_mr > 5) { fmla(v30.v8h(), v14.v8h(), v5.h()[5]); } tst(x0, 7); // BLOCK 6 fmla(v21.v8h(), v15.v8h(), v3.h()[1]); if (max_mr > 1) { fmla(v23.v8h(), v15.v8h(), v3.h()[5]); } if (max_mr > 2) { fmla(v25.v8h(), v15.v8h(), v4.h()[1]); } add(x5, x5, 64); // BLOCK 7 if (max_mr > 3) { fmla(v27.v8h(), v15.v8h(), v4.h()[5]); } if (max_mr > 4) { fmla(v29.v8h(), v15.v8h(), v5.h()[1]); } if (max_mr > 5) { fmla(v31.v8h(), v15.v8h(), v5.h()[5]); } // Is there a remainder?- 2 halffloats of A (4 bytes) or less b_ne(l5); bind(l4); // ks loop subs(x9, x9, max_mr * sizeof(void*)); // ks -= MR * sizeof(void*) b_hi(l1); // Clamp dup(v4.v8h(), v6.h()[0]); dup(v5.v8h(), v6.h()[1]); ldr(x0, mem[sp, 64]); // cn_stride if (clamp_min) { fmax(v20.v8h(), v20.v8h(), v4.v8h()); fmax(v21.v8h(), v21.v8h(), v4.v8h()); if (max_mr > 1) { fmax(v22.v8h(), v22.v8h(), v4.v8h()); fmax(v23.v8h(), v23.v8h(), v4.v8h()); } if (max_mr > 2) { fmax(v24.v8h(), v24.v8h(), v4.v8h()); fmax(v25.v8h(), v25.v8h(), v4.v8h()); } if (max_mr > 3) { fmax(v26.v8h(), v26.v8h(), v4.v8h()); fmax(v27.v8h(), v27.v8h(), v4.v8h()); } if (max_mr > 4) { fmax(v28.v8h(), v28.v8h(), v4.v8h()); fmax(v29.v8h(), v29.v8h(), v4.v8h()); } if (max_mr > 5) { fmax(v30.v8h(), v30.v8h(), v4.v8h()); fmax(v31.v8h(), v31.v8h(), v4.v8h()); } } subs(x1, x1, 16); if (clamp_max) { fmin(v20.v8h(), v20.v8h(), v5.v8h()); fmin(v21.v8h(), v21.v8h(), v5.v8h()); if (max_mr > 1) { fmin(v22.v8h(), v22.v8h(), v5.v8h()); fmin(v23.v8h(), v23.v8h(), v5.v8h()); } if (max_mr > 2) { fmin(v24.v8h(), v24.v8h(), v5.v8h()); fmin(v25.v8h(), v25.v8h(), v5.v8h()); } if (max_mr > 3) { fmin(v26.v8h(), v26.v8h(), v5.v8h()); fmin(v27.v8h(), v27.v8h(), v5.v8h()); } if (max_mr > 4) { fmin(v28.v8h(), v28.v8h(), v5.v8h()); fmin(v29.v8h(), v29.v8h(), v5.v8h()); } if (max_mr > 5) { fmin(v30.v8h(), v30.v8h(), v5.v8h()); fmin(v31.v8h(), v31.v8h(), v5.v8h()); } } // Store full 6 x 16 b_lo(l7); if (max_mr > 5) { st1({v30.v16b(), v31.v16b()}, mem[x7], x0); } if (max_mr > 4) { st1({v28.v16b(), v29.v16b()}, mem[x13], x0); } if (max_mr > 3) { st1({v26.v16b(), v27.v16b()}, mem[x10], x0); } if (max_mr > 2) { st1({v24.v16b(), v25.v16b()}, mem[x17], x0); } if (max_mr > 1) { st1({v22.v16b(), v23.v16b()}, mem[x16], x0); } st1({v20.v16b(), v21.v16b()}, mem[x6], x0); sub(x4, x4, x3); // a -= ks // nc loop b_hi(l0); // Restore x20-x23, d12-d15 from stack ldp(x22, x23, mem[sp, 48]); ldp(x20, x21, mem[sp, 32]); ldp(d14, d15, mem[sp, 16]); ldp(d12, d13, mem[sp], 64); ret(); bind(l5); // Is there a remainder?- 2 halffloats of A (4 bytes) tbz(x0, 2, l6); // Remainder- 2 halffloats of A (4 bytes) ldr(s0, mem[x14], 4); // A0 ldp(q16, q17, mem[x5], 32); // B if (max_mr > 2) { ldr(s1, mem[x20], 4); // A2 } if (max_mr > 4) { ldr(s2, mem[x22], 4); // A4 } if (max_mr > 1) { ld1({v0.s()}, 2, mem[x15], 4); // A1 } if (max_mr > 3) { ld1({v1.s()}, 2, mem[x21], 4); // A3 } if (max_mr > 5) { ld1({v2.s()}, 2, mem[x23], 4); // A5 } ldr(q18, mem[x5], 16); ldr(q19, mem[x5], 16); fmla(v20.v8h(), v16.v8h(), v0.h()[0]); if (max_mr > 1) { fmla(v22.v8h(), v16.v8h(), v0.h()[4]); } if (max_mr > 2) { fmla(v24.v8h(), v16.v8h(), v1.h()[0]); } if (max_mr > 3) { fmla(v26.v8h(), v16.v8h(), v1.h()[4]); } if (max_mr > 4) { fmla(v28.v8h(), v16.v8h(), v2.h()[0]); } if (max_mr > 5) { fmla(v30.v8h(), v16.v8h(), v2.h()[4]); } fmla(v21.v8h(), v17.v8h(), v0.h()[0]); if (max_mr > 1) { fmla(v23.v8h(), v17.v8h(), v0.h()[4]); } if (max_mr > 2) { fmla(v25.v8h(), v17.v8h(), v1.h()[0]); } if (max_mr > 3) { fmla(v27.v8h(), v17.v8h(), v1.h()[4]); } if (max_mr > 4) { fmla(v29.v8h(), v17.v8h(), v2.h()[0]); } if (max_mr > 5) { fmla(v31.v8h(), v17.v8h(), v2.h()[4]); } fmla(v20.v8h(), v18.v8h(), v0.h()[1]); if (max_mr > 1) { fmla(v22.v8h(), v18.v8h(), v0.h()[5]); } if (max_mr > 2) { fmla(v24.v8h(), v18.v8h(), v1.h()[1]); } if (max_mr > 3) { fmla(v26.v8h(), v18.v8h(), v1.h()[5]); } if (max_mr > 4) { fmla(v28.v8h(), v18.v8h(), v2.h()[1]); } if (max_mr > 5) { fmla(v30.v8h(), v18.v8h(), v2.h()[5]); } fmla(v21.v8h(), v19.v8h(), v0.h()[1]); if (max_mr > 1) { fmla(v23.v8h(), v19.v8h(), v0.h()[5]); } if (max_mr > 2) { fmla(v25.v8h(), v19.v8h(), v1.h()[1]); } if (max_mr > 3) { fmla(v27.v8h(), v19.v8h(), v1.h()[5]); } if (max_mr > 4) { fmla(v29.v8h(), v19.v8h(), v2.h()[1]); } if (max_mr > 5) { fmla(v31.v8h(), v19.v8h(), v2.h()[5]); } // Is there a remainder?- 1 halffloat of A (2 bytes) tbz(x0, 1, l4); bind(l6); // Remainder- 1 halffloat of A (2 bytes) ldr(h0, mem[x14], 2); // A0 ldp(q16, q17, mem[x5], 32); // B if (max_mr > 2) { ldr(h1, mem[x20], 2); // A2 } if (max_mr > 4) { ldr(h2, mem[x22], 2); // A4 } if (max_mr > 1) { ld1({v0.h()}, 4, mem[x15], 2); // A1 } if (max_mr > 3) { ld1({v1.h()}, 4, mem[x21], 2); // A3 } if (max_mr > 5) { ld1({v2.h()}, 4, mem[x23], 2); // A5 } fmla(v20.v8h(), v16.v8h(), v0.h()[0]); if (max_mr > 1) { fmla(v22.v8h(), v16.v8h(), v0.h()[4]); } if (max_mr > 2) { fmla(v24.v8h(), v16.v8h(), v1.h()[0]); } if (max_mr > 3) { fmla(v26.v8h(), v16.v8h(), v1.h()[4]); } if (max_mr > 4) { fmla(v28.v8h(), v16.v8h(), v2.h()[0]); } if (max_mr > 5) { fmla(v30.v8h(), v16.v8h(), v2.h()[4]); } fmla(v21.v8h(), v17.v8h(), v0.h()[0]); if (max_mr > 1) { fmla(v23.v8h(), v17.v8h(), v0.h()[4]); } if (max_mr > 2) { fmla(v25.v8h(), v17.v8h(), v1.h()[0]); } if (max_mr > 3) { fmla(v27.v8h(), v17.v8h(), v1.h()[4]); } if (max_mr > 4) { fmla(v29.v8h(), v17.v8h(), v2.h()[0]); } if (max_mr > 5) { fmla(v31.v8h(), v17.v8h(), v2.h()[4]); } b(l4); // Store odd width bind(l7); tbz(x1, 3, l8); if (max_mr > 5) { str(q30, mem[x7], 16); mov(v30.v16b(), v31.v16b()); } if (max_mr > 4) { str(q28, mem[x13], 16); mov(v28.v16b(), v29.v16b()); } if (max_mr > 3) { str(q26, mem[x10], 16); mov(v26.v16b(), v27.v16b()); } if (max_mr > 2) { str(q24, mem[x17], 16); mov(v24.v16b(), v25.v16b()); } if (max_mr > 1) { str(q22, mem[x16], 16); mov(v22.v16b(), v23.v16b()); } str(q20, mem[x6], 16); mov(v20.v16b(), v21.v16b()); bind(l8); tbz(x1, 2, l9); if (max_mr > 5) { str(d30, mem[x7], 8); } if (max_mr > 4) { str(d28, mem[x13], 8); } if (max_mr > 5) { dup(d30, v30.d()[1]); } if (max_mr > 4) { dup(d28, v28.d()[1]); } if (max_mr > 3) { str(d26, mem[x10], 8); } if (max_mr > 2) { str(d24, mem[x17], 8); } if (max_mr > 3) { dup(d26, v26.d()[1]); } if (max_mr > 2) { dup(d24, v24.d()[1]); } if (max_mr > 1) { str(d22, mem[x16], 8); } str(d20, mem[x6], 8); if (max_mr > 1) { dup(d22, v22.d()[1]); } dup(d20, v20.d()[1]); bind(l9); tbz(x1, 1, l10); if (max_mr > 5) { str(s30, mem[x7], 4); } if (max_mr > 4) { str(s28, mem[x13], 4); } if (max_mr > 5) { dup(s30, v30.s()[1]); } if (max_mr > 4) { dup(s28, v28.s()[1]); } if (max_mr > 3) { str(s26, mem[x10], 4); } if (max_mr > 2) { str(s24, mem[x17], 4); } if (max_mr > 3) { dup(s26, v26.s()[1]); } if (max_mr > 2) { dup(s24, v24.s()[1]); } if (max_mr > 1) { str(s22, mem[x16], 4); } str(s20, mem[x6], 4); if (max_mr > 1) { dup(s22, v22.s()[1]); } dup(s20, v20.s()[1]); bind(l10); tbz(x1, 0, l11); if (max_mr > 5) { str(h30, mem[x7]); } if (max_mr > 4) { str(h28, mem[x13]); } if (max_mr > 3) { str(h26, mem[x10]); } if (max_mr > 2) { str(h24, mem[x17]); } if (max_mr > 1) { str(h22, mem[x16]); } str(h20, mem[x6]); bind(l11); // Restore x20-x23, d12-d15 from stack ldp(x22, x23, mem[sp, 48]); ldp(x20, x21, mem[sp, 32]); ldp(d14, d15, mem[sp, 16]); ldp(d12, d13, mem[sp], 64); ret(); align(16, AlignInstruction::kHlt); } } // namespace } // namespace aarch64 } // namespace xnnpack xnn_status_t xnn_generate_f16_igemm_ukernel_6x16__aarch64_neonfp16arith_cortex_a55r0(xnn_code_buffer* code, size_t max_mr, size_t nc_mod_nr, size_t kc, size_t ks, const void* params) { using namespace xnnpack::aarch64; Generator g(code); assert(params != nullptr); g.generate(max_mr, nc_mod_nr, kc, ks, static_cast(params)); g.finalize(); if (g.error() != xnnpack::Error::kNoError) { return xnn_status_invalid_state; } return xnn_status_success; }