|
|
|
|
|
|
|
|
|
|
|
#include <cassert> |
|
#include <cstddef> |
|
#include <limits> |
|
|
|
#include <xnnpack.h> |
|
#include <xnnpack/aarch64-assembler.h> |
|
#include <xnnpack/gemm.h> |
|
#include <xnnpack/log.h> |
|
#include <xnnpack/memory.h> |
|
#include <xnnpack/microparams.h> |
|
#include <xnnpack/post-operation.h> |
|
|
|
namespace xnnpack { |
|
namespace aarch64 { |
|
namespace { |
|
class Generator : public MacroAssembler { |
|
using MacroAssembler::MacroAssembler; |
|
|
|
public: |
|
void generate(size_t max_mr, size_t nc_mod_nr, size_t kc, const jit_gemm_params* jit_gemm_params); |
|
void perform_post_operations(size_t max_mr, size_t num_post_operations, const xnn_post_operation* post_operations); |
|
}; |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
void Generator::generate(size_t max_mr, size_t nc_mod_nr, size_t kc, const jit_gemm_params* jit_gemm_params) |
|
{ |
|
assert(max_mr <= 1); |
|
assert(nc_mod_nr < 8); |
|
assert(kc != 0); |
|
assert(kc % sizeof(float) == 0); |
|
|
|
Label l0, l1, l2, l3, l4, l5, l6, l7; |
|
const size_t num_post_operations = jit_gemm_params->num_post_operations; |
|
const xnn_post_operation* post_operations = jit_gemm_params->post_operations; |
|
const float min = jit_gemm_params->f32_minmax.min; |
|
const float max = jit_gemm_params->f32_minmax.max; |
|
const bool clamp_min = min != -std::numeric_limits<float>::infinity(); |
|
const bool clamp_max = max != +std::numeric_limits<float>::infinity(); |
|
assert(num_post_operations == 0 || (!clamp_min && !clamp_max)); |
|
|
|
|
|
ldp(x14, x8, mem[sp]); |
|
|
|
|
|
if (clamp_min || clamp_max) { |
|
ld2r({v4.v4s(), v5.v4s()}, mem[x8]); |
|
} |
|
bind(l0); |
|
|
|
ldp(q16, q17, mem[x5], 32); |
|
subs(x0, x2, 8); |
|
|
|
b_lo(l3); |
|
|
|
|
|
bind(l1); |
|
ldr(d0, mem[x3], 8); |
|
ldp(q20, q21, mem[x5], 32); |
|
ldp(q22, q23, mem[x5], 32); |
|
subs(x0, x0, 8); |
|
fmla(v16.v4s(), v20.v4s(), v0.s()[0]); |
|
fmla(v17.v4s(), v21.v4s(), v0.s()[0]); |
|
fmla(v16.v4s(), v22.v4s(), v0.s()[1]); |
|
fmla(v17.v4s(), v23.v4s(), v0.s()[1]); |
|
b_hs(l1); |
|
|
|
|
|
tbnz(x0, 2, l3); |
|
|
|
bind(l2); |
|
subs(x1, x1, 8); |
|
|
|
|
|
if (clamp_min) { |
|
fmax(v16.v4s(), v16.v4s(), v4.v4s()); |
|
fmax(v17.v4s(), v17.v4s(), v4.v4s()); |
|
} |
|
if (clamp_max) { |
|
fmin(v16.v4s(), v16.v4s(), v5.v4s()); |
|
fmin(v17.v4s(), v17.v4s(), v5.v4s()); |
|
} |
|
perform_post_operations(max_mr, num_post_operations, post_operations); |
|
|
|
|
|
b_lo(l4); |
|
|
|
stp(q16, q17, mem[x6]); |
|
add(x6, x6, x14); |
|
|
|
sub(x3, x3, x2); |
|
b_hi(l0); |
|
ret(); |
|
|
|
bind(l3); |
|
|
|
ldr(s0, mem[x3], 4); |
|
ldp(q20, q21, mem[x5], 32); |
|
fmla(v16.v4s(), v20.v4s(), v0.s()[0]); |
|
fmla(v17.v4s(), v21.v4s(), v0.s()[0]); |
|
b(l2); |
|
|
|
|
|
bind(l4); |
|
tbz(x1, 2, l5); |
|
str(q16, mem[x6], 16); |
|
mov(v16.v16b(), v17.v16b()); |
|
|
|
bind(l5); |
|
tbz(x1, 1, l6); |
|
str(d16, mem[x6], 8); |
|
dup(d16, v16.d()[1]); |
|
|
|
bind(l6); |
|
tbz(x1, 0, l7); |
|
str(s16, mem[x6]); |
|
bind(l7); |
|
ret(); |
|
|
|
align(16, AlignInstruction::kHlt); |
|
} |
|
|
|
void Generator::perform_post_operations( |
|
size_t max_mr, |
|
size_t num_post_operations, |
|
const xnn_post_operation* post_operations) |
|
{ |
|
if (num_post_operations == 0) { |
|
return; |
|
} |
|
for (size_t i = 0; i < num_post_operations; i++) { |
|
switch (post_operations[i].op_type) { |
|
case xnn_post_operation_type_hardswish: { |
|
|
|
const auto sixth = v0.v4s(); |
|
const auto three = v1.v4s(); |
|
const auto six = v2.v4s(); |
|
const auto zero = v3.v4s(); |
|
|
|
ld3r({sixth, three, six}, mem[x8]++); |
|
movi(zero, 0); |
|
const VRegister accs[] = { |
|
v16.v4s(), v17.v4s(), |
|
}; |
|
const VRegister tmps[] = {v4.v4s(), v5.v4s()}; |
|
f32_hardswish(sixth, three, six, zero, &accs[0], XNN_COUNT_OF(accs), &tmps[0], XNN_COUNT_OF(tmps)); |
|
break; |
|
} |
|
default: |
|
XNN_LOG_UNREACHABLE("unsupported post operation: %u", post_operations[i].op_type); |
|
} |
|
} |
|
} |
|
|
|
} |
|
} |
|
} |
|
|
|
// Entry point that JIT-generates the F32 GEMM 1x8 micro-kernel into `code`.
//
// Parameters:
//   code      - code buffer handed to the generator.
//   max_mr, nc_mod_nr, kc - forwarded unchanged to Generator::generate().
//   params    - must be non-null; reinterpreted as const jit_gemm_params*.
//
// Returns xnn_status_success when the assembler reports no error after
// finalize(), and xnn_status_invalid_state otherwise.
xnn_status_t xnn_generate_f32_gemm_ukernel_1x8__aarch64_neonfma_ld64(xnn_code_buffer* code, size_t max_mr, size_t nc_mod_nr, size_t kc, const void* params) {
  xnnpack::aarch64::Generator generator(code);
  assert(params != nullptr);

  const jit_gemm_params* const gemm_params = static_cast<const jit_gemm_params*>(params);
  generator.generate(max_mr, nc_mod_nr, kc, gemm_params);
  generator.finalize();

  return generator.error() == xnnpack::Error::kNoError
      ? xnn_status_success
      : xnn_status_invalid_state;
}
|
|