|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#ifndef LYRA_CODEC_SPARSE_MATMUL_COMPUTE_KERNELS_ARM_H_ |
|
#define LYRA_CODEC_SPARSE_MATMUL_COMPUTE_KERNELS_ARM_H_ |
|
|
|
#if defined __aarch64__ |
|
|
|
#include <arm_neon.h> |
|
|
|
#include <type_traits> |
|
|
|
#include "sparse_matmul/numerics/fixed_types.h" |
|
#include "sparse_matmul/numerics/float16_types.h" |
|
#include "sparse_matmul/numerics/type_utils.h" |
|
|
|
// Numeric local labels for the inline-assembly kernels in this header. Each
// macro expands to a string literal so that it can be spliced into an asm
// template by adjacent-literal concatenation: `LABEL_X ":\n"` defines the
// label, while `"bne " LABEL_X "b\n"` / `"beq " LABEL_X "f\n"` branch to the
// nearest definition backwards / forwards.
#define LABEL_COL_LOOP "1"       // Top of the per-row loop over non-zero blocks.
#define LABEL_ROW_LOOP "2"       // Top of the loop over block rows.
#define LABEL_SKIP_COL_LOOP "3"  // Branch target for rows with no blocks.
// NOTE(review): LABEL_TOP_LOOP is not referenced by the kernels visible in this
// part of the file; presumably used further down - confirm before removing.
#define LABEL_TOP_LOOP "4"
|
|
|
namespace csrblocksparse { |
|
namespace detail { |
|
|
|
template <typename T> |
|
struct IsFloatOrBfloat |
|
: std::integral_constant<bool, std::is_same<T, float>::value || |
|
std::is_same<T, bfloat16>::value> {}; |
|
|
|
template <typename WeightType, typename RhsType, typename OutType> |
|
struct IsAllowableFloatTypes |
|
: std::integral_constant<bool, IsFloatOrBfloat<WeightType>::value && |
|
std::is_same<RhsType, float>::value && |
|
std::is_same<OutType, float>::value> {}; |
|
|
|
|
|
|
|
|
|
template <typename WeightType, typename RhsType, typename OutType> |
|
struct IsAllowableFixedTypes |
|
: std::integral_constant<bool, (IsFixed16Type<WeightType>::value && |
|
IsFixed16Type<RhsType>::value) && |
|
(IsFixed32Type<OutType>::value || |
|
IsFixed16Type<OutType>::value)> {}; |
|
|
|
template <typename WeightType, typename RhsType, typename OutType> |
|
struct ShouldEnableGenericKernel |
|
: std::integral_constant< |
|
bool, |
|
!IsAllowableFloatTypes<WeightType, RhsType, OutType>::value && |
|
!IsAllowableFixedTypes<WeightType, RhsType, OutType>::value> {}; |
|
|
|
template <typename WeightType, typename RhsType, typename OutType> |
|
struct ShouldEnableGenericSpMV_4x4 |
|
: ShouldEnableGenericKernel<WeightType, RhsType, OutType> {}; |
|
template <typename WeightType, typename RhsType, typename OutType> |
|
struct ShouldEnableGenericSpMM5_4x4 |
|
: ShouldEnableGenericKernel<WeightType, RhsType, OutType> {}; |
|
// Switch for the generic SpMV_1x1 implementation: unconditionally true for
// every type triple.
template <typename WeightType, typename RhsType, typename OutType>
struct ShouldEnableGenericSpMV_1x1 : std::true_type {};
|
// Switch for the generic SpMM5_1x1 implementation: unconditionally true for
// every type triple.
template <typename WeightType, typename RhsType, typename OutType>
struct ShouldEnableGenericSpMM5_1x1 : std::true_type {};
|
template <typename Type> |
|
struct IsAddableFixedTypes |
|
: std::integral_constant<bool, IsFixed32Type<Type>::value || |
|
IsFixed16Type<Type>::value> {}; |
|
template <typename Type> |
|
struct ShouldEnableGenericAdd |
|
: std::integral_constant<bool, !IsAddableFixedTypes<Type>::value> {}; |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
template <typename WeightType, typename RhsType, typename OutType> |
|
typename std::enable_if<std::is_same<WeightType, bfloat16>::value && |
|
std::is_same<RhsType, float>::value && |
|
std::is_same<OutType, float>::value>::type |
|
SpMV_4x4(const bfloat16* weights_ptr, const int16_t* col_deltas_bytes, |
|
const int32_t* nnz_per_row, const float* rhs_ptr, |
|
const float* bias_ptr, float* out_ptr, int64_t assigned_rows, |
|
int64_t rows , |
|
int64_t cols , int relu) { |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if (relu) { |
|
asm( |
|
|
|
"ldrsh x7, [%[col_deltas_bytes]], #2\n" |
|
"ldrsh x8, [%[col_deltas_bytes]], #2\n" |
|
|
|
"add %[rhs_ptr], %[rhs_ptr], x7\n" |
|
|
|
"movi v25.4s, #0\n" |
|
|
|
LABEL_ROW_LOOP |
|
":\n" |
|
|
|
|
|
"ld1 {v27.4s}, [%[bias_ptr]], #16\n" |
|
|
|
|
|
"dup v28.4s, v27.s[0]\n" |
|
"dup v29.4s, v27.s[1]\n" |
|
"dup v30.4s, v27.s[2]\n" |
|
"dup v31.4s, v27.s[3]\n" |
|
|
|
|
|
"ldr w6, [%[nnz_per_row]], #4\n" |
|
"cmp w6, #0\n" |
|
|
|
"beq " LABEL_SKIP_COL_LOOP "f\n" |
|
|
|
LABEL_COL_LOOP |
|
":\n" |
|
|
|
"ld1 {v0.4s}, [%[rhs_ptr]], x8\n" |
|
|
|
|
|
"ldrsh x8, [%[col_deltas_bytes]], #2\n" |
|
|
|
|
|
"ld1 {v2.8h, v3.8h}, [%[weights_ptr]], #32\n" |
|
|
|
|
|
"shll v4.4s, v2.4h, #16\n" |
|
"shll2 v5.4s, v2.8h, #16\n" |
|
"shll v6.4s, v3.4h, #16\n" |
|
"shll2 v7.4s, v3.8h, #16\n" |
|
|
|
|
|
"fmla v28.4s, v4.4s, v0.4s\n" |
|
"fmla v29.4s, v5.4s, v0.4s\n" |
|
"fmla v30.4s, v6.4s, v0.4s\n" |
|
"fmla v31.4s, v7.4s, v0.4s\n" |
|
|
|
|
|
"subs w6, w6, #1\n" |
|
"bne " LABEL_COL_LOOP "b\n" |
|
|
|
LABEL_SKIP_COL_LOOP |
|
":\n" |
|
|
|
|
|
"faddp v28.4s, v28.4s, v29.4s\n" |
|
"faddp v30.4s, v30.4s, v31.4s\n" |
|
"faddp v28.4s, v28.4s, v30.4s\n" |
|
|
|
|
|
"fmax v28.4s, v28.4s, v25.4s\n" |
|
|
|
|
|
"st1 {v28.4s}, [%[out_ptr]], #16\n" |
|
|
|
|
|
"subs %[assigned_rows], %[assigned_rows], #1\n" |
|
"bne " LABEL_ROW_LOOP "b\n" |
|
|
|
|
|
: |
|
[out_ptr] "+r"(out_ptr), |
|
[weights_ptr] "+r"(weights_ptr), |
|
[col_deltas_bytes] "+r"(col_deltas_bytes), |
|
[bias_ptr] "+r"(bias_ptr), |
|
[nnz_per_row] "+r"(nnz_per_row), |
|
[assigned_rows] "+r"(assigned_rows), |
|
[rhs_ptr] "+r"(rhs_ptr) |
|
: |
|
: |
|
"cc", "memory", "x6", "x7", "x8", "v0", "v1", "v2", "v3", "v4", "v5", |
|
"v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", |
|
"v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", |
|
"v26", "v27", "v28", "v29", "v30", "v31"); |
|
|
|
} else { |
|
asm( |
|
|
|
"ldrsh x7, [%[col_deltas_bytes]], #2\n" |
|
"ldrsh x8, [%[col_deltas_bytes]], #2\n" |
|
|
|
"add %[rhs_ptr], %[rhs_ptr], x7\n" |
|
|
|
LABEL_ROW_LOOP |
|
":\n" |
|
|
|
|
|
"ld1 {v27.4s}, [%[bias_ptr]], #16\n" |
|
|
|
|
|
"dup v28.4s, v27.s[0]\n" |
|
"dup v29.4s, v27.s[1]\n" |
|
"dup v30.4s, v27.s[2]\n" |
|
"dup v31.4s, v27.s[3]\n" |
|
|
|
|
|
"ldr w6, [%[nnz_per_row]], #4\n" |
|
"cmp w6, #0\n" |
|
|
|
"beq " LABEL_SKIP_COL_LOOP "f\n" |
|
|
|
LABEL_COL_LOOP |
|
":\n" |
|
|
|
"ld1 {v0.4s}, [%[rhs_ptr]], x8\n" |
|
|
|
|
|
"ldrsh x8, [%[col_deltas_bytes]], #2\n" |
|
|
|
|
|
"ld1 {v2.8h, v3.8h}, [%[weights_ptr]], #32\n" |
|
|
|
|
|
"shll v4.4s, v2.4h, #16\n" |
|
"shll2 v5.4s, v2.8h, #16\n" |
|
"shll v6.4s, v3.4h, #16\n" |
|
"shll2 v7.4s, v3.8h, #16\n" |
|
|
|
|
|
"fmla v28.4s, v4.4s, v0.4s\n" |
|
"fmla v29.4s, v5.4s, v0.4s\n" |
|
"fmla v30.4s, v6.4s, v0.4s\n" |
|
"fmla v31.4s, v7.4s, v0.4s\n" |
|
|
|
|
|
"subs w6, w6, #1\n" |
|
"bne " LABEL_COL_LOOP "b\n" |
|
|
|
LABEL_SKIP_COL_LOOP |
|
":\n" |
|
|
|
|
|
"faddp v28.4s, v28.4s, v29.4s\n" |
|
"faddp v30.4s, v30.4s, v31.4s\n" |
|
"faddp v28.4s, v28.4s, v30.4s\n" |
|
|
|
|
|
"st1 {v28.4s}, [%[out_ptr]], #16\n" |
|
|
|
|
|
"subs %[assigned_rows], %[assigned_rows], #1\n" |
|
"bne " LABEL_ROW_LOOP "b\n" |
|
|
|
|
|
: |
|
[out_ptr] "+r"(out_ptr), |
|
[weights_ptr] "+r"(weights_ptr), |
|
[col_deltas_bytes] "+r"(col_deltas_bytes), |
|
[bias_ptr] "+r"(bias_ptr), |
|
[nnz_per_row] "+r"(nnz_per_row), |
|
[assigned_rows] "+r"(assigned_rows), |
|
[rhs_ptr] "+r"(rhs_ptr) |
|
: |
|
: |
|
"cc", "memory", "x6", "x7", "x8", "v0", "v1", "v2", "v3", "v4", "v5", |
|
"v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", |
|
"v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", |
|
"v26", "v27", "v28", "v29", "v30", "v31"); |
|
|
|
} |
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
template <typename WeightType, typename RhsType, typename OutType> |
|
typename std::enable_if<std::is_same<WeightType, bfloat16>::value && |
|
std::is_same<RhsType, float>::value && |
|
std::is_same<OutType, float>::value>::type |
|
SpMM5_4x4(const bfloat16* weights_ptr, const int16_t* col_deltas_bytes, |
|
const int32_t* nnz_per_row, const float* rhs_ptr, |
|
const float* bias_ptr, float* out_ptr, int64_t assigned_rows, |
|
int64_t rows, int64_t cols, int relu) { |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
const float* rhs2_ptr = rhs_ptr + cols; |
|
float* out2_ptr = out_ptr + rows; |
|
const float* rhs3_ptr = rhs_ptr + 2 * cols; |
|
float* out3_ptr = out_ptr + 2 * rows; |
|
const float* rhs4_ptr = rhs_ptr + 3 * cols; |
|
float* out4_ptr = out_ptr + 3 * rows; |
|
const float* rhs5_ptr = rhs_ptr + 4 * cols; |
|
float* out5_ptr = out_ptr + 4 * rows; |
|
if (relu) { |
|
asm( |
|
|
|
"ldrsh x7, [%[col_deltas_bytes]], #2\n" |
|
"ldrsh x8, [%[col_deltas_bytes]], #2\n" |
|
|
|
"add %[rhs_ptr], %[rhs_ptr], x7\n" |
|
"add %[rhs2_ptr], %[rhs2_ptr], x7\n" |
|
"add %[rhs3_ptr], %[rhs3_ptr], x7\n" |
|
"add %[rhs4_ptr], %[rhs4_ptr], x7\n" |
|
"add %[rhs5_ptr], %[rhs5_ptr], x7\n" |
|
|
|
LABEL_ROW_LOOP |
|
":\n" |
|
|
|
|
|
"ld1 {v27.4s}, [%[bias_ptr]], #16\n" |
|
|
|
|
|
"dup v28.4s, v27.s[0]\n" |
|
"dup v29.4s, v27.s[1]\n" |
|
"dup v30.4s, v27.s[2]\n" |
|
"dup v31.4s, v27.s[3]\n" |
|
"dup v23.4s, v27.s[0]\n" |
|
"dup v24.4s, v27.s[1]\n" |
|
"dup v25.4s, v27.s[2]\n" |
|
"dup v26.4s, v27.s[3]\n" |
|
"dup v19.4s, v27.s[0]\n" |
|
"dup v20.4s, v27.s[1]\n" |
|
"dup v21.4s, v27.s[2]\n" |
|
"dup v22.4s, v27.s[3]\n" |
|
"dup v15.4s, v27.s[0]\n" |
|
"dup v16.4s, v27.s[1]\n" |
|
"dup v17.4s, v27.s[2]\n" |
|
"dup v18.4s, v27.s[3]\n" |
|
"dup v11.4s, v27.s[0]\n" |
|
"dup v12.4s, v27.s[1]\n" |
|
"dup v13.4s, v27.s[2]\n" |
|
"dup v14.4s, v27.s[3]\n" |
|
|
|
|
|
"ldr w6, [%[nnz_per_row]], #4\n" |
|
"cmp w6, #0\n" |
|
|
|
"beq " LABEL_SKIP_COL_LOOP "f\n" |
|
|
|
LABEL_COL_LOOP |
|
":\n" |
|
|
|
"ld1 {v0.4s}, [%[rhs_ptr]], x8\n" |
|
"ld1 {v1.4s}, [%[rhs2_ptr]], x8\n" |
|
"ld1 {v8.4s}, [%[rhs3_ptr]], x8\n" |
|
"ld1 {v9.4s}, [%[rhs4_ptr]], x8\n" |
|
"ld1 {v10.4s}, [%[rhs5_ptr]], x8\n" |
|
|
|
|
|
"ldrsh x8, [%[col_deltas_bytes]], #2\n" |
|
|
|
|
|
"ld1 {v2.8h, v3.8h}, [%[weights_ptr]], #32\n" |
|
|
|
|
|
"shll v4.4s, v2.4h, #16\n" |
|
"shll2 v5.4s, v2.8h, #16\n" |
|
"shll v6.4s, v3.4h, #16\n" |
|
"shll2 v7.4s, v3.8h, #16\n" |
|
|
|
|
|
"fmla v28.4s, v4.4s, v0.4s\n" |
|
"fmla v29.4s, v5.4s, v0.4s\n" |
|
"fmla v30.4s, v6.4s, v0.4s\n" |
|
"fmla v31.4s, v7.4s, v0.4s\n" |
|
"fmla v23.4s, v4.4s, v1.4s\n" |
|
"fmla v24.4s, v5.4s, v1.4s\n" |
|
"fmla v25.4s, v6.4s, v1.4s\n" |
|
"fmla v26.4s, v7.4s, v1.4s\n" |
|
"fmla v19.4s, v4.4s, v8.4s\n" |
|
"fmla v20.4s, v5.4s, v8.4s\n" |
|
"fmla v21.4s, v6.4s, v8.4s\n" |
|
"fmla v22.4s, v7.4s, v8.4s\n" |
|
"fmla v15.4s, v4.4s, v9.4s\n" |
|
"fmla v16.4s, v5.4s, v9.4s\n" |
|
"fmla v17.4s, v6.4s, v9.4s\n" |
|
"fmla v18.4s, v7.4s, v9.4s\n" |
|
"fmla v11.4s, v4.4s, v10.4s\n" |
|
"fmla v12.4s, v5.4s, v10.4s\n" |
|
"fmla v13.4s, v6.4s, v10.4s\n" |
|
"fmla v14.4s, v7.4s, v10.4s\n" |
|
|
|
|
|
"subs w6, w6, #1\n" |
|
"bne " LABEL_COL_LOOP "b\n" |
|
|
|
LABEL_SKIP_COL_LOOP |
|
":\n" |
|
|
|
"movi v0.4s, #0\n" |
|
"faddp v28.4s, v28.4s, v29.4s\n" |
|
"faddp v23.4s, v23.4s, v24.4s\n" |
|
"faddp v19.4s, v19.4s, v20.4s\n" |
|
"faddp v15.4s, v15.4s, v16.4s\n" |
|
"faddp v11.4s, v11.4s, v12.4s\n" |
|
|
|
"faddp v30.4s, v30.4s, v31.4s\n" |
|
"faddp v25.4s, v25.4s, v26.4s\n" |
|
"faddp v21.4s, v21.4s, v22.4s\n" |
|
"faddp v17.4s, v17.4s, v18.4s\n" |
|
"faddp v13.4s, v13.4s, v14.4s\n" |
|
|
|
"faddp v28.4s, v28.4s, v30.4s\n" |
|
"faddp v23.4s, v23.4s, v25.4s\n" |
|
"faddp v19.4s, v19.4s, v21.4s\n" |
|
"faddp v15.4s, v15.4s, v17.4s\n" |
|
"faddp v11.4s, v11.4s, v13.4s\n" |
|
|
|
|
|
"fmax v28.4s, v28.4s, v0.4s\n" |
|
"fmax v23.4s, v23.4s, v0.4s\n" |
|
"fmax v19.4s, v19.4s, v0.4s\n" |
|
"fmax v15.4s, v15.4s, v0.4s\n" |
|
"fmax v11.4s, v11.4s, v0.4s\n" |
|
|
|
|
|
"st1 {v28.4s}, [%[out_ptr]], #16\n" |
|
"st1 {v23.4s}, [%[out2_ptr]], #16\n" |
|
"st1 {v19.4s}, [%[out3_ptr]], #16\n" |
|
"st1 {v15.4s}, [%[out4_ptr]], #16\n" |
|
"st1 {v11.4s}, [%[out5_ptr]], #16\n" |
|
|
|
|
|
"subs %[assigned_rows], %[assigned_rows], #1\n" |
|
"bne " LABEL_ROW_LOOP "b\n" |
|
|
|
|
|
: |
|
[out_ptr] "+r"(out_ptr), |
|
[out2_ptr] "+r"(out2_ptr), |
|
[out3_ptr] "+r"(out3_ptr), |
|
[out4_ptr] "+r"(out4_ptr), |
|
[out5_ptr] "+r"(out5_ptr), |
|
[weights_ptr] "+r"(weights_ptr), |
|
[col_deltas_bytes] "+r"(col_deltas_bytes), |
|
[bias_ptr] "+r"(bias_ptr), |
|
[nnz_per_row] "+r"(nnz_per_row), |
|
[assigned_rows] "+r"(assigned_rows), |
|
[rhs_ptr] "+r"(rhs_ptr), |
|
[rhs2_ptr] "+r"(rhs2_ptr), |
|
[rhs3_ptr] "+r"(rhs3_ptr), |
|
[rhs4_ptr] "+r"(rhs4_ptr), |
|
[rhs5_ptr] "+r"(rhs5_ptr) |
|
: |
|
: |
|
"cc", "memory", "x6", "x7", "x8", "v0", "v1", "v2", "v3", "v4", "v5", |
|
"v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", |
|
"v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", |
|
"v26", "v27", "v28", "v29", "v30", "v31"); |
|
|
|
} else { |
|
asm( |
|
|
|
"ldrsh x7, [%[col_deltas_bytes]], #2\n" |
|
"ldrsh x8, [%[col_deltas_bytes]], #2\n" |
|
|
|
"add %[rhs_ptr], %[rhs_ptr], x7\n" |
|
"add %[rhs2_ptr], %[rhs2_ptr], x7\n" |
|
"add %[rhs3_ptr], %[rhs3_ptr], x7\n" |
|
"add %[rhs4_ptr], %[rhs4_ptr], x7\n" |
|
"add %[rhs5_ptr], %[rhs5_ptr], x7\n" |
|
|
|
LABEL_ROW_LOOP |
|
":\n" |
|
|
|
|
|
"ld1 {v27.4s}, [%[bias_ptr]], #16\n" |
|
|
|
|
|
"dup v28.4s, v27.s[0]\n" |
|
"dup v29.4s, v27.s[1]\n" |
|
"dup v30.4s, v27.s[2]\n" |
|
"dup v31.4s, v27.s[3]\n" |
|
"dup v23.4s, v27.s[0]\n" |
|
"dup v24.4s, v27.s[1]\n" |
|
"dup v25.4s, v27.s[2]\n" |
|
"dup v26.4s, v27.s[3]\n" |
|
"dup v19.4s, v27.s[0]\n" |
|
"dup v20.4s, v27.s[1]\n" |
|
"dup v21.4s, v27.s[2]\n" |
|
"dup v22.4s, v27.s[3]\n" |
|
"dup v15.4s, v27.s[0]\n" |
|
"dup v16.4s, v27.s[1]\n" |
|
"dup v17.4s, v27.s[2]\n" |
|
"dup v18.4s, v27.s[3]\n" |
|
"dup v11.4s, v27.s[0]\n" |
|
"dup v12.4s, v27.s[1]\n" |
|
"dup v13.4s, v27.s[2]\n" |
|
"dup v14.4s, v27.s[3]\n" |
|
|
|
|
|
"ldr w6, [%[nnz_per_row]], #4\n" |
|
"cmp w6, #0\n" |
|
|
|
"beq " LABEL_SKIP_COL_LOOP "f\n" |
|
|
|
LABEL_COL_LOOP |
|
":\n" |
|
|
|
"ld1 {v0.4s}, [%[rhs_ptr]], x8\n" |
|
"ld1 {v1.4s}, [%[rhs2_ptr]], x8\n" |
|
"ld1 {v8.4s}, [%[rhs3_ptr]], x8\n" |
|
"ld1 {v9.4s}, [%[rhs4_ptr]], x8\n" |
|
"ld1 {v10.4s}, [%[rhs5_ptr]], x8\n" |
|
|
|
|
|
"ldrsh x8, [%[col_deltas_bytes]], #2\n" |
|
|
|
|
|
"ld1 {v2.8h, v3.8h}, [%[weights_ptr]], #32\n" |
|
|
|
|
|
"shll v4.4s, v2.4h, #16\n" |
|
"shll2 v5.4s, v2.8h, #16\n" |
|
"shll v6.4s, v3.4h, #16\n" |
|
"shll2 v7.4s, v3.8h, #16\n" |
|
|
|
|
|
"fmla v28.4s, v4.4s, v0.4s\n" |
|
"fmla v29.4s, v5.4s, v0.4s\n" |
|
"fmla v30.4s, v6.4s, v0.4s\n" |
|
"fmla v31.4s, v7.4s, v0.4s\n" |
|
"fmla v23.4s, v4.4s, v1.4s\n" |
|
"fmla v24.4s, v5.4s, v1.4s\n" |
|
"fmla v25.4s, v6.4s, v1.4s\n" |
|
"fmla v26.4s, v7.4s, v1.4s\n" |
|
"fmla v19.4s, v4.4s, v8.4s\n" |
|
"fmla v20.4s, v5.4s, v8.4s\n" |
|
"fmla v21.4s, v6.4s, v8.4s\n" |
|
"fmla v22.4s, v7.4s, v8.4s\n" |
|
"fmla v15.4s, v4.4s, v9.4s\n" |
|
"fmla v16.4s, v5.4s, v9.4s\n" |
|
"fmla v17.4s, v6.4s, v9.4s\n" |
|
"fmla v18.4s, v7.4s, v9.4s\n" |
|
"fmla v11.4s, v4.4s, v10.4s\n" |
|
"fmla v12.4s, v5.4s, v10.4s\n" |
|
"fmla v13.4s, v6.4s, v10.4s\n" |
|
"fmla v14.4s, v7.4s, v10.4s\n" |
|
|
|
|
|
"subs w6, w6, #1\n" |
|
"bne " LABEL_COL_LOOP "b\n" |
|
|
|
LABEL_SKIP_COL_LOOP |
|
":\n" |
|
|
|
|
|
"faddp v28.4s, v28.4s, v29.4s\n" |
|
"faddp v23.4s, v23.4s, v24.4s\n" |
|
"faddp v19.4s, v19.4s, v20.4s\n" |
|
"faddp v15.4s, v15.4s, v16.4s\n" |
|
"faddp v11.4s, v11.4s, v12.4s\n" |
|
|
|
"faddp v30.4s, v30.4s, v31.4s\n" |
|
"faddp v25.4s, v25.4s, v26.4s\n" |
|
"faddp v21.4s, v21.4s, v22.4s\n" |
|
"faddp v17.4s, v17.4s, v18.4s\n" |
|
"faddp v13.4s, v13.4s, v14.4s\n" |
|
|
|
"faddp v28.4s, v28.4s, v30.4s\n" |
|
"faddp v23.4s, v23.4s, v25.4s\n" |
|
"faddp v19.4s, v19.4s, v21.4s\n" |
|
"faddp v15.4s, v15.4s, v17.4s\n" |
|
"faddp v11.4s, v11.4s, v13.4s\n" |
|
|
|
|
|
"st1 {v28.4s}, [%[out_ptr]], #16\n" |
|
"st1 {v23.4s}, [%[out2_ptr]], #16\n" |
|
"st1 {v19.4s}, [%[out3_ptr]], #16\n" |
|
"st1 {v15.4s}, [%[out4_ptr]], #16\n" |
|
"st1 {v11.4s}, [%[out5_ptr]], #16\n" |
|
|
|
|
|
"subs %[assigned_rows], %[assigned_rows], #1\n" |
|
"bne " LABEL_ROW_LOOP "b\n" |
|
|
|
|
|
: |
|
[out_ptr] "+r"(out_ptr), |
|
[out2_ptr] "+r"(out2_ptr), |
|
[out3_ptr] "+r"(out3_ptr), |
|
[out4_ptr] "+r"(out4_ptr), |
|
[out5_ptr] "+r"(out5_ptr), |
|
[weights_ptr] "+r"(weights_ptr), |
|
[col_deltas_bytes] "+r"(col_deltas_bytes), |
|
[bias_ptr] "+r"(bias_ptr), |
|
[nnz_per_row] "+r"(nnz_per_row), |
|
[assigned_rows] "+r"(assigned_rows), |
|
[rhs_ptr] "+r"(rhs_ptr), |
|
[rhs2_ptr] "+r"(rhs2_ptr), |
|
[rhs3_ptr] "+r"(rhs3_ptr), |
|
[rhs4_ptr] "+r"(rhs4_ptr), |
|
[rhs5_ptr] "+r"(rhs5_ptr) |
|
: |
|
: |
|
"cc", "memory", "x6", "x7", "x8", "v0", "v1", "v2", "v3", "v4", "v5", |
|
"v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", |
|
"v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", |
|
"v26", "v27", "v28", "v29", "v30", "v31"); |
|
|
|
} |
|
} |
|
|
|
|
|
|
|
template <typename WeightType, typename RhsType, typename OutType> |
|
typename std::enable_if<std::is_same<WeightType, float>::value && |
|
std::is_same<RhsType, float>::value && |
|
std::is_same<OutType, float>::value>::type |
|
SpMV_4x4(const float* weights_ptr, const int16_t* col_deltas_bytes, |
|
const int32_t* nnz_per_row, const float* rhs_ptr, |
|
const float* bias_ptr, float* out_ptr, int64_t assigned_rows, |
|
int64_t rows , |
|
int64_t cols , int relu) { |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if (relu) { |
|
asm( |
|
|
|
"ldrsh x7, [%[col_deltas_bytes]], #2\n" |
|
"ldrsh x8, [%[col_deltas_bytes]], #2\n" |
|
|
|
"add %[rhs_ptr], %[rhs_ptr], x7\n" |
|
|
|
"movi v25.4s, #0\n" |
|
|
|
LABEL_ROW_LOOP |
|
":\n" |
|
|
|
|
|
"ld1 {v27.4s}, [%[bias_ptr]], #16\n" |
|
|
|
|
|
"dup v28.4s, v27.s[0]\n" |
|
"dup v29.4s, v27.s[1]\n" |
|
"dup v30.4s, v27.s[2]\n" |
|
"dup v31.4s, v27.s[3]\n" |
|
|
|
|
|
"ldr w6, [%[nnz_per_row]], #4\n" |
|
"cmp w6, #0\n" |
|
|
|
"beq " LABEL_SKIP_COL_LOOP "f\n" |
|
|
|
LABEL_COL_LOOP |
|
":\n" |
|
|
|
"ld1 {v0.4s}, [%[rhs_ptr]], x8\n" |
|
|
|
|
|
"ldrsh x8, [%[col_deltas_bytes]], #2\n" |
|
|
|
|
|
"ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [%[weights_ptr]], #64\n" |
|
|
|
|
|
"fmla v28.4s, v4.4s, v0.4s\n" |
|
"fmla v29.4s, v5.4s, v0.4s\n" |
|
"fmla v30.4s, v6.4s, v0.4s\n" |
|
"fmla v31.4s, v7.4s, v0.4s\n" |
|
|
|
|
|
"subs w6, w6, #1\n" |
|
"bne " LABEL_COL_LOOP "b\n" |
|
|
|
LABEL_SKIP_COL_LOOP |
|
":\n" |
|
|
|
|
|
"faddp v28.4s, v28.4s, v29.4s\n" |
|
"faddp v30.4s, v30.4s, v31.4s\n" |
|
"faddp v28.4s, v28.4s, v30.4s\n" |
|
|
|
|
|
"fmax v28.4s, v28.4s, v25.4s\n" |
|
|
|
|
|
"st1 {v28.4s}, [%[out_ptr]], #16\n" |
|
|
|
|
|
"subs %[assigned_rows], %[assigned_rows], #1\n" |
|
"bne " LABEL_ROW_LOOP "b\n" |
|
|
|
|
|
: |
|
[out_ptr] "+r"(out_ptr), |
|
[weights_ptr] "+r"(weights_ptr), |
|
[col_deltas_bytes] "+r"(col_deltas_bytes), |
|
[bias_ptr] "+r"(bias_ptr), |
|
[nnz_per_row] "+r"(nnz_per_row), |
|
[assigned_rows] "+r"(assigned_rows), |
|
[rhs_ptr] "+r"(rhs_ptr) |
|
: |
|
: |
|
"cc", "memory", "x6", "x7", "x8", "v0", "v1", "v2", "v3", "v4", "v5", |
|
"v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", |
|
"v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", |
|
"v26", "v27", "v28", "v29", "v30", "v31"); |
|
|
|
} else { |
|
asm( |
|
|
|
"ldrsh x7, [%[col_deltas_bytes]], #2\n" |
|
"ldrsh x8, [%[col_deltas_bytes]], #2\n" |
|
|
|
"add %[rhs_ptr], %[rhs_ptr], x7\n" |
|
|
|
LABEL_ROW_LOOP |
|
":\n" |
|
|
|
|
|
"ld1 {v27.4s}, [%[bias_ptr]], #16\n" |
|
|
|
|
|
"dup v28.4s, v27.s[0]\n" |
|
"dup v29.4s, v27.s[1]\n" |
|
"dup v30.4s, v27.s[2]\n" |
|
"dup v31.4s, v27.s[3]\n" |
|
|
|
|
|
"ldr w6, [%[nnz_per_row]], #4\n" |
|
"cmp w6, #0\n" |
|
|
|
"beq " LABEL_SKIP_COL_LOOP "f\n" |
|
|
|
LABEL_COL_LOOP |
|
":\n" |
|
|
|
"ld1 {v0.4s}, [%[rhs_ptr]], x8\n" |
|
|
|
|
|
"ldrsh x8, [%[col_deltas_bytes]], #2\n" |
|
|
|
|
|
"ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [%[weights_ptr]], #64\n" |
|
|
|
|
|
"fmla v28.4s, v4.4s, v0.4s\n" |
|
"fmla v29.4s, v5.4s, v0.4s\n" |
|
"fmla v30.4s, v6.4s, v0.4s\n" |
|
"fmla v31.4s, v7.4s, v0.4s\n" |
|
|
|
|
|
"subs w6, w6, #1\n" |
|
"bne " LABEL_COL_LOOP "b\n" |
|
|
|
LABEL_SKIP_COL_LOOP |
|
":\n" |
|
|
|
|
|
"faddp v28.4s, v28.4s, v29.4s\n" |
|
"faddp v30.4s, v30.4s, v31.4s\n" |
|
"faddp v28.4s, v28.4s, v30.4s\n" |
|
|
|
|
|
"st1 {v28.4s}, [%[out_ptr]], #16\n" |
|
|
|
|
|
"subs %[assigned_rows], %[assigned_rows], #1\n" |
|
"bne " LABEL_ROW_LOOP "b\n" |
|
|
|
|
|
: |
|
[out_ptr] "+r"(out_ptr), |
|
[weights_ptr] "+r"(weights_ptr), |
|
[col_deltas_bytes] "+r"(col_deltas_bytes), |
|
[bias_ptr] "+r"(bias_ptr), |
|
[nnz_per_row] "+r"(nnz_per_row), |
|
[assigned_rows] "+r"(assigned_rows), |
|
[rhs_ptr] "+r"(rhs_ptr) |
|
: |
|
: |
|
"cc", "memory", "x6", "x7", "x8", "v0", "v1", "v2", "v3", "v4", "v5", |
|
"v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", |
|
"v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", |
|
"v26", "v27", "v28", "v29", "v30", "v31"); |
|
|
|
} |
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
template <typename WeightType, typename RhsType, typename OutType> |
|
typename std::enable_if<std::is_same<WeightType, float>::value && |
|
std::is_same<RhsType, float>::value && |
|
std::is_same<OutType, float>::value>::type |
|
SpMM5_4x4(const float* weights_ptr, const int16_t* col_deltas_bytes, |
|
const int32_t* nnz_per_row, const float* rhs_ptr, |
|
const float* bias_ptr, float* out_ptr, int64_t assigned_rows, |
|
int64_t rows, int64_t cols, int relu) { |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
const float* rhs2_ptr = rhs_ptr + cols; |
|
float* out2_ptr = out_ptr + rows; |
|
const float* rhs3_ptr = rhs_ptr + 2 * cols; |
|
float* out3_ptr = out_ptr + 2 * rows; |
|
const float* rhs4_ptr = rhs_ptr + 3 * cols; |
|
float* out4_ptr = out_ptr + 3 * rows; |
|
const float* rhs5_ptr = rhs_ptr + 4 * cols; |
|
float* out5_ptr = out_ptr + 4 * rows; |
|
if (relu) { |
|
asm( |
|
|
|
"ldrsh x7, [%[col_deltas_bytes]], #2\n" |
|
"ldrsh x8, [%[col_deltas_bytes]], #2\n" |
|
|
|
"add %[rhs_ptr], %[rhs_ptr], x7\n" |
|
"add %[rhs2_ptr], %[rhs2_ptr], x7\n" |
|
"add %[rhs3_ptr], %[rhs3_ptr], x7\n" |
|
"add %[rhs4_ptr], %[rhs4_ptr], x7\n" |
|
"add %[rhs5_ptr], %[rhs5_ptr], x7\n" |
|
|
|
LABEL_ROW_LOOP |
|
":\n" |
|
|
|
"ld1 {v27.4s}, [%[bias_ptr]], #16\n" |
|
|
|
|
|
"dup v28.4s, v27.s[0]\n" |
|
"dup v29.4s, v27.s[1]\n" |
|
"dup v30.4s, v27.s[2]\n" |
|
"dup v31.4s, v27.s[3]\n" |
|
"dup v23.4s, v27.s[0]\n" |
|
"dup v24.4s, v27.s[1]\n" |
|
"dup v25.4s, v27.s[2]\n" |
|
"dup v26.4s, v27.s[3]\n" |
|
"dup v19.4s, v27.s[0]\n" |
|
"dup v20.4s, v27.s[1]\n" |
|
"dup v21.4s, v27.s[2]\n" |
|
"dup v22.4s, v27.s[3]\n" |
|
"dup v15.4s, v27.s[0]\n" |
|
"dup v16.4s, v27.s[1]\n" |
|
"dup v17.4s, v27.s[2]\n" |
|
"dup v18.4s, v27.s[3]\n" |
|
"dup v11.4s, v27.s[0]\n" |
|
"dup v12.4s, v27.s[1]\n" |
|
"dup v13.4s, v27.s[2]\n" |
|
"dup v14.4s, v27.s[3]\n" |
|
|
|
|
|
"ldr w6, [%[nnz_per_row]], #4\n" |
|
"cmp w6, #0\n" |
|
|
|
"beq " LABEL_SKIP_COL_LOOP "f\n" |
|
|
|
LABEL_COL_LOOP |
|
":\n" |
|
|
|
"ld1 {v0.4s}, [%[rhs_ptr]], x8\n" |
|
"ld1 {v1.4s}, [%[rhs2_ptr]], x8\n" |
|
"ld1 {v8.4s}, [%[rhs3_ptr]], x8\n" |
|
"ld1 {v9.4s}, [%[rhs4_ptr]], x8\n" |
|
"ld1 {v10.4s}, [%[rhs5_ptr]], x8\n" |
|
|
|
|
|
"ldrsh x8, [%[col_deltas_bytes]], #2\n" |
|
|
|
|
|
"ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [%[weights_ptr]], #64\n" |
|
|
|
|
|
"fmla v28.4s, v4.4s, v0.4s\n" |
|
"fmla v29.4s, v5.4s, v0.4s\n" |
|
"fmla v30.4s, v6.4s, v0.4s\n" |
|
"fmla v31.4s, v7.4s, v0.4s\n" |
|
"fmla v23.4s, v4.4s, v1.4s\n" |
|
"fmla v24.4s, v5.4s, v1.4s\n" |
|
"fmla v25.4s, v6.4s, v1.4s\n" |
|
"fmla v26.4s, v7.4s, v1.4s\n" |
|
"fmla v19.4s, v4.4s, v8.4s\n" |
|
"fmla v20.4s, v5.4s, v8.4s\n" |
|
"fmla v21.4s, v6.4s, v8.4s\n" |
|
"fmla v22.4s, v7.4s, v8.4s\n" |
|
"fmla v15.4s, v4.4s, v9.4s\n" |
|
"fmla v16.4s, v5.4s, v9.4s\n" |
|
"fmla v17.4s, v6.4s, v9.4s\n" |
|
"fmla v18.4s, v7.4s, v9.4s\n" |
|
"fmla v11.4s, v4.4s, v10.4s\n" |
|
"fmla v12.4s, v5.4s, v10.4s\n" |
|
"fmla v13.4s, v6.4s, v10.4s\n" |
|
"fmla v14.4s, v7.4s, v10.4s\n" |
|
|
|
|
|
"subs w6, w6, #1\n" |
|
"bne " LABEL_COL_LOOP "b\n" |
|
|
|
LABEL_SKIP_COL_LOOP |
|
":\n" |
|
|
|
"movi v0.4s, #0\n" |
|
"faddp v28.4s, v28.4s, v29.4s\n" |
|
"faddp v23.4s, v23.4s, v24.4s\n" |
|
"faddp v19.4s, v19.4s, v20.4s\n" |
|
"faddp v15.4s, v15.4s, v16.4s\n" |
|
"faddp v11.4s, v11.4s, v12.4s\n" |
|
|
|
"faddp v30.4s, v30.4s, v31.4s\n" |
|
"faddp v25.4s, v25.4s, v26.4s\n" |
|
"faddp v21.4s, v21.4s, v22.4s\n" |
|
"faddp v17.4s, v17.4s, v18.4s\n" |
|
"faddp v13.4s, v13.4s, v14.4s\n" |
|
|
|
"faddp v28.4s, v28.4s, v30.4s\n" |
|
"faddp v23.4s, v23.4s, v25.4s\n" |
|
"faddp v19.4s, v19.4s, v21.4s\n" |
|
"faddp v15.4s, v15.4s, v17.4s\n" |
|
"faddp v11.4s, v11.4s, v13.4s\n" |
|
|
|
|
|
"fmax v28.4s, v28.4s, v0.4s\n" |
|
"fmax v23.4s, v23.4s, v0.4s\n" |
|
"fmax v19.4s, v19.4s, v0.4s\n" |
|
"fmax v15.4s, v15.4s, v0.4s\n" |
|
"fmax v11.4s, v11.4s, v0.4s\n" |
|
|
|
|
|
"st1 {v28.4s}, [%[out_ptr]], #16\n" |
|
"st1 {v23.4s}, [%[out2_ptr]], #16\n" |
|
"st1 {v19.4s}, [%[out3_ptr]], #16\n" |
|
"st1 {v15.4s}, [%[out4_ptr]], #16\n" |
|
"st1 {v11.4s}, [%[out5_ptr]], #16\n" |
|
|
|
|
|
"subs %[assigned_rows], %[assigned_rows], #1\n" |
|
"bne " LABEL_ROW_LOOP "b\n" |
|
|
|
|
|
: |
|
[out_ptr] "+r"(out_ptr), |
|
[out2_ptr] "+r"(out2_ptr), |
|
[out3_ptr] "+r"(out3_ptr), |
|
[out4_ptr] "+r"(out4_ptr), |
|
[out5_ptr] "+r"(out5_ptr), |
|
[weights_ptr] "+r"(weights_ptr), |
|
[col_deltas_bytes] "+r"(col_deltas_bytes), |
|
[bias_ptr] "+r"(bias_ptr), |
|
[nnz_per_row] "+r"(nnz_per_row), |
|
[assigned_rows] "+r"(assigned_rows), |
|
[rhs_ptr] "+r"(rhs_ptr), |
|
[rhs2_ptr] "+r"(rhs2_ptr), |
|
[rhs3_ptr] "+r"(rhs3_ptr), |
|
[rhs4_ptr] "+r"(rhs4_ptr), |
|
[rhs5_ptr] "+r"(rhs5_ptr) |
|
: |
|
: |
|
"cc", "memory", "x6", "x7", "x8", "v0", "v1", "v2", "v3", "v4", "v5", |
|
"v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", |
|
"v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", |
|
"v26", "v27", "v28", "v29", "v30", "v31"); |
|
|
|
} else { |
|
asm( |
|
|
|
"ldrsh x7, [%[col_deltas_bytes]], #2\n" |
|
"ldrsh x8, [%[col_deltas_bytes]], #2\n" |
|
|
|
"add %[rhs_ptr], %[rhs_ptr], x7\n" |
|
"add %[rhs2_ptr], %[rhs2_ptr], x7\n" |
|
"add %[rhs3_ptr], %[rhs3_ptr], x7\n" |
|
"add %[rhs4_ptr], %[rhs4_ptr], x7\n" |
|
"add %[rhs5_ptr], %[rhs5_ptr], x7\n" |
|
|
|
LABEL_ROW_LOOP |
|
":\n" |
|
|
|
|
|
"ld1 {v27.4s}, [%[bias_ptr]], #16\n" |
|
|
|
|
|
"dup v28.4s, v27.s[0]\n" |
|
"dup v29.4s, v27.s[1]\n" |
|
"dup v30.4s, v27.s[2]\n" |
|
"dup v31.4s, v27.s[3]\n" |
|
"dup v23.4s, v27.s[0]\n" |
|
"dup v24.4s, v27.s[1]\n" |
|
"dup v25.4s, v27.s[2]\n" |
|
"dup v26.4s, v27.s[3]\n" |
|
"dup v19.4s, v27.s[0]\n" |
|
"dup v20.4s, v27.s[1]\n" |
|
"dup v21.4s, v27.s[2]\n" |
|
"dup v22.4s, v27.s[3]\n" |
|
"dup v15.4s, v27.s[0]\n" |
|
"dup v16.4s, v27.s[1]\n" |
|
"dup v17.4s, v27.s[2]\n" |
|
"dup v18.4s, v27.s[3]\n" |
|
"dup v11.4s, v27.s[0]\n" |
|
"dup v12.4s, v27.s[1]\n" |
|
"dup v13.4s, v27.s[2]\n" |
|
"dup v14.4s, v27.s[3]\n" |
|
|
|
|
|
"ldr w6, [%[nnz_per_row]], #4\n" |
|
"cmp w6, #0\n" |
|
|
|
"beq " LABEL_SKIP_COL_LOOP "f\n" |
|
|
|
LABEL_COL_LOOP |
|
":\n" |
|
|
|
"ld1 {v0.4s}, [%[rhs_ptr]], x8\n" |
|
"ld1 {v1.4s}, [%[rhs2_ptr]], x8\n" |
|
"ld1 {v8.4s}, [%[rhs3_ptr]], x8\n" |
|
"ld1 {v9.4s}, [%[rhs4_ptr]], x8\n" |
|
"ld1 {v10.4s}, [%[rhs5_ptr]], x8\n" |
|
|
|
|
|
"ldrsh x8, [%[col_deltas_bytes]], #2\n" |
|
|
|
|
|
"ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [%[weights_ptr]], #64\n" |
|
|
|
|
|
"fmla v28.4s, v4.4s, v0.4s\n" |
|
"fmla v29.4s, v5.4s, v0.4s\n" |
|
"fmla v30.4s, v6.4s, v0.4s\n" |
|
"fmla v31.4s, v7.4s, v0.4s\n" |
|
"fmla v23.4s, v4.4s, v1.4s\n" |
|
"fmla v24.4s, v5.4s, v1.4s\n" |
|
"fmla v25.4s, v6.4s, v1.4s\n" |
|
"fmla v26.4s, v7.4s, v1.4s\n" |
|
"fmla v19.4s, v4.4s, v8.4s\n" |
|
"fmla v20.4s, v5.4s, v8.4s\n" |
|
"fmla v21.4s, v6.4s, v8.4s\n" |
|
"fmla v22.4s, v7.4s, v8.4s\n" |
|
"fmla v15.4s, v4.4s, v9.4s\n" |
|
"fmla v16.4s, v5.4s, v9.4s\n" |
|
"fmla v17.4s, v6.4s, v9.4s\n" |
|
"fmla v18.4s, v7.4s, v9.4s\n" |
|
"fmla v11.4s, v4.4s, v10.4s\n" |
|
"fmla v12.4s, v5.4s, v10.4s\n" |
|
"fmla v13.4s, v6.4s, v10.4s\n" |
|
"fmla v14.4s, v7.4s, v10.4s\n" |
|
|
|
|
|
"subs w6, w6, #1\n" |
|
"bne " LABEL_COL_LOOP "b\n" |
|
|
|
LABEL_SKIP_COL_LOOP |
|
":\n" |
|
|
|
|
|
"faddp v28.4s, v28.4s, v29.4s\n" |
|
"faddp v23.4s, v23.4s, v24.4s\n" |
|
"faddp v19.4s, v19.4s, v20.4s\n" |
|
"faddp v15.4s, v15.4s, v16.4s\n" |
|
"faddp v11.4s, v11.4s, v12.4s\n" |
|
|
|
"faddp v30.4s, v30.4s, v31.4s\n" |
|
"faddp v25.4s, v25.4s, v26.4s\n" |
|
"faddp v21.4s, v21.4s, v22.4s\n" |
|
"faddp v17.4s, v17.4s, v18.4s\n" |
|
"faddp v13.4s, v13.4s, v14.4s\n" |
|
|
|
"faddp v28.4s, v28.4s, v30.4s\n" |
|
"faddp v23.4s, v23.4s, v25.4s\n" |
|
"faddp v19.4s, v19.4s, v21.4s\n" |
|
"faddp v15.4s, v15.4s, v17.4s\n" |
|
"faddp v11.4s, v11.4s, v13.4s\n" |
|
|
|
|
|
"st1 {v28.4s}, [%[out_ptr]], #16\n" |
|
"st1 {v23.4s}, [%[out2_ptr]], #16\n" |
|
"st1 {v19.4s}, [%[out3_ptr]], #16\n" |
|
"st1 {v15.4s}, [%[out4_ptr]], #16\n" |
|
"st1 {v11.4s}, [%[out5_ptr]], #16\n" |
|
|
|
|
|
"subs %[assigned_rows], %[assigned_rows], #1\n" |
|
"bne " LABEL_ROW_LOOP "b\n" |
|
|
|
|
|
: |
|
[out_ptr] "+r"(out_ptr), |
|
[out2_ptr] "+r"(out2_ptr), |
|
[out3_ptr] "+r"(out3_ptr), |
|
[out4_ptr] "+r"(out4_ptr), |
|
[out5_ptr] "+r"(out5_ptr), |
|
[weights_ptr] "+r"(weights_ptr), |
|
[col_deltas_bytes] "+r"(col_deltas_bytes), |
|
[bias_ptr] "+r"(bias_ptr), |
|
[nnz_per_row] "+r"(nnz_per_row), |
|
[assigned_rows] "+r"(assigned_rows), |
|
[rhs_ptr] "+r"(rhs_ptr), |
|
[rhs2_ptr] "+r"(rhs2_ptr), |
|
[rhs3_ptr] "+r"(rhs3_ptr), |
|
[rhs4_ptr] "+r"(rhs4_ptr), |
|
[rhs5_ptr] "+r"(rhs5_ptr) |
|
: |
|
: |
|
"cc", "memory", "x6", "x7", "x8", "v0", "v1", "v2", "v3", "v4", "v5", |
|
"v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", |
|
"v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", |
|
"v26", "v27", "v28", "v29", "v30", "v31"); |
|
|
|
} |
|
} |
|
|
|
|
|
|
|
template <typename WeightType, typename RhsType, typename OutType> |
|
typename std::enable_if< |
|
IsFixed16Type<WeightType>::value && IsFixed16Type<RhsType>::value && |
|
std::is_same<OutType, typename TypeOfProduct<WeightType, |
|
RhsType>::type>::value>::type |
|
SpMV_4x4(const WeightType* weights_ptr, const int16_t* col_deltas_bytes, |
|
const int32_t* nnz_per_row, const RhsType* rhs_ptr, |
|
const typename TypeOfProduct<WeightType, RhsType>::type* bias_ptr, |
|
OutType* out_ptr, int64_t assigned_rows, |
|
int64_t rows , |
|
int64_t cols , int relu) { |
|
if (relu) { |
|
asm( |
|
|
|
"ldrsh x7, [%[col_deltas_bytes]], #2\n" |
|
"ldrsh x8, [%[col_deltas_bytes]], #2\n" |
|
|
|
"add %[rhs_ptr], %[rhs_ptr], x7\n" |
|
|
|
"movi v25.4s, #0\n" |
|
|
|
LABEL_ROW_LOOP |
|
":\n" |
|
|
|
|
|
"ld1 {v27.4s}, [%[bias_ptr]], #16\n" |
|
|
|
|
|
"dup v28.4s, v27.s[0]\n" |
|
"dup v29.4s, v27.s[1]\n" |
|
"dup v30.4s, v27.s[2]\n" |
|
"dup v31.4s, v27.s[3]\n" |
|
|
|
|
|
"ldr w6, [%[nnz_per_row]], #4\n" |
|
"cmp w6, #0\n" |
|
|
|
"beq " LABEL_SKIP_COL_LOOP "f\n" |
|
|
|
LABEL_COL_LOOP |
|
":\n" |
|
|
|
"ld1 {v0.4h}, [%[rhs_ptr]], x8\n" |
|
|
|
"mov v0.d[1], v0.d[0]\n" |
|
|
|
|
|
"ldrsh x8, [%[col_deltas_bytes]], #2\n" |
|
|
|
|
|
"ld1 {v2.8h, v3.8h}, [%[weights_ptr]], #32\n" |
|
|
|
|
|
"smlal v28.4s, v2.4h, v0.4h\n" |
|
"smlal2 v29.4s, v2.8h, v0.8h\n" |
|
"smlal v30.4s, v3.4h, v0.4h\n" |
|
"smlal2 v31.4s, v3.8h, v0.8h\n" |
|
|
|
|
|
"subs w6, w6, #1\n" |
|
"bne " LABEL_COL_LOOP "b\n" |
|
|
|
LABEL_SKIP_COL_LOOP |
|
":\n" |
|
|
|
|
|
"addp v28.4s, v28.4s, v29.4s\n" |
|
"addp v30.4s, v30.4s, v31.4s\n" |
|
"addp v28.4s, v28.4s, v30.4s\n" |
|
|
|
|
|
"smax v28.4s, v28.4s, v25.4s\n" |
|
|
|
|
|
"st1 {v28.4s}, [%[out_ptr]], #16\n" |
|
|
|
|
|
"subs %[assigned_rows], %[assigned_rows], #1\n" |
|
"bne " LABEL_ROW_LOOP "b\n" |
|
|
|
|
|
: |
|
[out_ptr] "+r"(out_ptr), [weights_ptr] "+r"(weights_ptr), |
|
[col_deltas_bytes] "+r"(col_deltas_bytes), [bias_ptr] "+r"(bias_ptr), |
|
[nnz_per_row] "+r"(nnz_per_row), [assigned_rows] "+r"(assigned_rows), |
|
[rhs_ptr] "+r"(rhs_ptr) |
|
: |
|
: |
|
"cc", "memory", "x6", "x7", "x8", "v0", "v1", "v2", "v3", "v4", "v5", |
|
"v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", |
|
"v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", |
|
"v26", "v27", "v28", "v29", "v30", "v31"); |
|
|
|
} else { |
|
asm( |
|
|
|
"ldrsh x7, [%[col_deltas_bytes]], #2\n" |
|
"ldrsh x8, [%[col_deltas_bytes]], #2\n" |
|
|
|
"add %[rhs_ptr], %[rhs_ptr], x7\n" |
|
|
|
"movi v25.4s, #0\n" |
|
|
|
LABEL_ROW_LOOP |
|
":\n" |
|
|
|
|
|
"ld1 {v27.4s}, [%[bias_ptr]], #16\n" |
|
|
|
|
|
"dup v28.4s, v27.s[0]\n" |
|
"dup v29.4s, v27.s[1]\n" |
|
"dup v30.4s, v27.s[2]\n" |
|
"dup v31.4s, v27.s[3]\n" |
|
|
|
|
|
"ldr w6, [%[nnz_per_row]], #4\n" |
|
"cmp w6, #0\n" |
|
|
|
"beq " LABEL_SKIP_COL_LOOP "f\n" |
|
|
|
LABEL_COL_LOOP |
|
":\n" |
|
|
|
"ld1 {v0.4h}, [%[rhs_ptr]], x8\n" |
|
|
|
"mov v0.d[1], v0.d[0]\n" |
|
|
|
|
|
"ldrsh x8, [%[col_deltas_bytes]], #2\n" |
|
|
|
|
|
"ld1 {v2.8h, v3.8h}, [%[weights_ptr]], #32\n" |
|
|
|
|
|
"smlal v28.4s, v2.4h, v0.4h\n" |
|
"smlal2 v29.4s, v2.8h, v0.8h\n" |
|
"smlal v30.4s, v3.4h, v0.4h\n" |
|
"smlal2 v31.4s, v3.8h, v0.8h\n" |
|
|
|
|
|
"subs w6, w6, #1\n" |
|
"bne " LABEL_COL_LOOP "b\n" |
|
|
|
LABEL_SKIP_COL_LOOP |
|
":\n" |
|
|
|
|
|
"addp v28.4s, v28.4s, v29.4s\n" |
|
"addp v30.4s, v30.4s, v31.4s\n" |
|
"addp v28.4s, v28.4s, v30.4s\n" |
|
|
|
|
|
"st1 {v28.4s}, [%[out_ptr]], #16\n" |
|
|
|
|
|
"subs %[assigned_rows], %[assigned_rows], #1\n" |
|
"bne " LABEL_ROW_LOOP "b\n" |
|
|
|
|
|
: |
|
[out_ptr] "+r"(out_ptr), [weights_ptr] "+r"(weights_ptr), |
|
[col_deltas_bytes] "+r"(col_deltas_bytes), [bias_ptr] "+r"(bias_ptr), |
|
[nnz_per_row] "+r"(nnz_per_row), [assigned_rows] "+r"(assigned_rows), |
|
[rhs_ptr] "+r"(rhs_ptr) |
|
: |
|
: |
|
"cc", "memory", "x6", "x7", "x8", "v0", "v1", "v2", "v3", "v4", "v5", |
|
"v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", |
|
"v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", |
|
"v26", "v27", "v28", "v29", "v30", "v31"); |
|
|
|
} |
|
} |
|
|
|
|
|
|
|
template <typename WeightType, typename RhsType, typename OutType> |
|
typename std::enable_if< |
|
IsFixed16Type<WeightType>::value && IsFixed16Type<RhsType>::value && |
|
std::is_same<OutType, typename TypeOfProduct<WeightType, |
|
RhsType>::type>::value>::type |
|
SpMM5_4x4(const WeightType* weights_ptr, const int16_t* col_deltas_bytes, |
|
const int32_t* nnz_per_row, const RhsType* rhs_ptr, |
|
const typename TypeOfProduct<WeightType, RhsType>::type* bias_ptr, |
|
OutType* out_ptr, int64_t assigned_rows, int64_t rows, int64_t cols, |
|
int relu) { |
|
|
|
const RhsType* rhs2_ptr = rhs_ptr + cols; |
|
OutType* out2_ptr = out_ptr + rows; |
|
const RhsType* rhs3_ptr = rhs_ptr + 2 * cols; |
|
OutType* out3_ptr = out_ptr + 2 * rows; |
|
const RhsType* rhs4_ptr = rhs_ptr + 3 * cols; |
|
OutType* out4_ptr = out_ptr + 3 * rows; |
|
const RhsType* rhs5_ptr = rhs_ptr + 4 * cols; |
|
OutType* out5_ptr = out_ptr + 4 * rows; |
|
if (relu) { |
|
asm( |
|
|
|
"ldrsh x7, [%[col_deltas_bytes]], #2\n" |
|
"ldrsh x8, [%[col_deltas_bytes]], #2\n" |
|
|
|
"add %[rhs_ptr], %[rhs_ptr], x7\n" |
|
"add %[rhs2_ptr], %[rhs2_ptr], x7\n" |
|
"add %[rhs3_ptr], %[rhs3_ptr], x7\n" |
|
"add %[rhs4_ptr], %[rhs4_ptr], x7\n" |
|
"add %[rhs5_ptr], %[rhs5_ptr], x7\n" |
|
|
|
LABEL_ROW_LOOP |
|
":\n" |
|
|
|
|
|
"ld1 {v27.4s}, [%[bias_ptr]], #16\n" |
|
|
|
|
|
"dup v28.4s, v27.s[0]\n" |
|
"dup v29.4s, v27.s[1]\n" |
|
"dup v30.4s, v27.s[2]\n" |
|
"dup v31.4s, v27.s[3]\n" |
|
"dup v23.4s, v27.s[0]\n" |
|
"dup v24.4s, v27.s[1]\n" |
|
"dup v25.4s, v27.s[2]\n" |
|
"dup v26.4s, v27.s[3]\n" |
|
"dup v19.4s, v27.s[0]\n" |
|
"dup v20.4s, v27.s[1]\n" |
|
"dup v21.4s, v27.s[2]\n" |
|
"dup v22.4s, v27.s[3]\n" |
|
"dup v15.4s, v27.s[0]\n" |
|
"dup v16.4s, v27.s[1]\n" |
|
"dup v17.4s, v27.s[2]\n" |
|
"dup v18.4s, v27.s[3]\n" |
|
"dup v11.4s, v27.s[0]\n" |
|
"dup v12.4s, v27.s[1]\n" |
|
"dup v13.4s, v27.s[2]\n" |
|
"dup v14.4s, v27.s[3]\n" |
|
|
|
|
|
"ldr w6, [%[nnz_per_row]], #4\n" |
|
"cmp w6, #0\n" |
|
|
|
"beq " LABEL_SKIP_COL_LOOP "f\n" |
|
|
|
LABEL_COL_LOOP |
|
":\n" |
|
|
|
"ld1 {v0.4h}, [%[rhs_ptr]], x8\n" |
|
"mov v0.d[1], v0.d[0]\n" |
|
"ld1 {v1.4h}, [%[rhs2_ptr]], x8\n" |
|
"mov v1.d[1], v1.d[0]\n" |
|
"ld1 {v8.4h}, [%[rhs3_ptr]], x8\n" |
|
"mov v8.d[1], v8.d[0]\n" |
|
"ld1 {v9.4h}, [%[rhs4_ptr]], x8\n" |
|
"mov v9.d[1], v9.d[0]\n" |
|
"ld1 {v10.4h}, [%[rhs5_ptr]], x8\n" |
|
"mov v10.d[1], v10.d[0]\n" |
|
|
|
|
|
"ldrsh x8, [%[col_deltas_bytes]], #2\n" |
|
|
|
|
|
"ld1 {v2.8h, v3.8h}, [%[weights_ptr]], #32\n" |
|
|
|
|
|
"smlal v28.4s, v2.4h, v0.4h\n" |
|
"smlal2 v29.4s, v2.8h, v0.8h\n" |
|
"smlal v30.4s, v3.4h, v0.4h\n" |
|
"smlal2 v31.4s, v3.8h, v0.8h\n" |
|
"smlal v23.4s, v2.4h, v1.4h\n" |
|
"smlal2 v24.4s, v2.8h, v1.8h\n" |
|
"smlal v25.4s, v3.4h, v1.4h\n" |
|
"smlal2 v26.4s, v3.8h, v1.8h\n" |
|
"smlal v19.4s, v2.4h, v8.4h\n" |
|
"smlal2 v20.4s, v2.8h, v8.8h\n" |
|
"smlal v21.4s, v3.4h, v8.4h\n" |
|
"smlal2 v22.4s, v3.8h, v8.8h\n" |
|
"smlal v15.4s, v2.4h, v9.4h\n" |
|
"smlal2 v16.4s, v2.8h, v9.8h\n" |
|
"smlal v17.4s, v3.4h, v9.4h\n" |
|
"smlal2 v18.4s, v3.8h, v9.8h\n" |
|
"smlal v11.4s, v2.4h, v10.4h\n" |
|
"smlal2 v12.4s, v2.8h, v10.8h\n" |
|
"smlal v13.4s, v3.4h, v10.4h\n" |
|
"smlal2 v14.4s, v3.8h, v10.8h\n" |
|
|
|
|
|
"subs w6, w6, #1\n" |
|
"bne " LABEL_COL_LOOP "b\n" |
|
|
|
LABEL_SKIP_COL_LOOP |
|
":\n" |
|
|
|
"movi v0.4s, #0\n" |
|
"addp v28.4s, v28.4s, v29.4s\n" |
|
"addp v23.4s, v23.4s, v24.4s\n" |
|
"addp v19.4s, v19.4s, v20.4s\n" |
|
"addp v15.4s, v15.4s, v16.4s\n" |
|
"addp v11.4s, v11.4s, v12.4s\n" |
|
|
|
"addp v30.4s, v30.4s, v31.4s\n" |
|
"addp v25.4s, v25.4s, v26.4s\n" |
|
"addp v21.4s, v21.4s, v22.4s\n" |
|
"addp v17.4s, v17.4s, v18.4s\n" |
|
"addp v13.4s, v13.4s, v14.4s\n" |
|
|
|
"addp v28.4s, v28.4s, v30.4s\n" |
|
"addp v23.4s, v23.4s, v25.4s\n" |
|
"addp v19.4s, v19.4s, v21.4s\n" |
|
"addp v15.4s, v15.4s, v17.4s\n" |
|
"addp v11.4s, v11.4s, v13.4s\n" |
|
|
|
|
|
"smax v28.4s, v28.4s, v0.4s\n" |
|
"smax v23.4s, v23.4s, v0.4s\n" |
|
"smax v19.4s, v19.4s, v0.4s\n" |
|
"smax v15.4s, v15.4s, v0.4s\n" |
|
"smax v11.4s, v11.4s, v0.4s\n" |
|
|
|
|
|
"st1 {v28.4s}, [%[out_ptr]], #16\n" |
|
"st1 {v23.4s}, [%[out2_ptr]], #16\n" |
|
"st1 {v19.4s}, [%[out3_ptr]], #16\n" |
|
"st1 {v15.4s}, [%[out4_ptr]], #16\n" |
|
"st1 {v11.4s}, [%[out5_ptr]], #16\n" |
|
|
|
|
|
"subs %[assigned_rows], %[assigned_rows], #1\n" |
|
"bne " LABEL_ROW_LOOP "b\n" |
|
|
|
|
|
: |
|
[out_ptr] "+r"(out_ptr), [out2_ptr] "+r"(out2_ptr), |
|
[out3_ptr] "+r"(out3_ptr), [out4_ptr] "+r"(out4_ptr), |
|
[out5_ptr] "+r"(out5_ptr), [weights_ptr] "+r"(weights_ptr), |
|
[col_deltas_bytes] "+r"(col_deltas_bytes), [bias_ptr] "+r"(bias_ptr), |
|
[nnz_per_row] "+r"(nnz_per_row), [assigned_rows] "+r"(assigned_rows), |
|
[rhs_ptr] "+r"(rhs_ptr), [rhs2_ptr] "+r"(rhs2_ptr), |
|
[rhs3_ptr] "+r"(rhs3_ptr), [rhs4_ptr] "+r"(rhs4_ptr), |
|
[rhs5_ptr] "+r"(rhs5_ptr) |
|
: |
|
: |
|
"cc", "memory", "x6", "x7", "x8", "v0", "v1", "v2", "v3", "v4", "v5", |
|
"v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", |
|
"v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", |
|
"v26", "v27", "v28", "v29", "v30", "v31"); |
|
|
|
} else { |
|
asm( |
|
|
|
"ldrsh x7, [%[col_deltas_bytes]], #2\n" |
|
"ldrsh x8, [%[col_deltas_bytes]], #2\n" |
|
|
|
"add %[rhs_ptr], %[rhs_ptr], x7\n" |
|
"add %[rhs2_ptr], %[rhs2_ptr], x7\n" |
|
"add %[rhs3_ptr], %[rhs3_ptr], x7\n" |
|
"add %[rhs4_ptr], %[rhs4_ptr], x7\n" |
|
"add %[rhs5_ptr], %[rhs5_ptr], x7\n" |
|
|
|
LABEL_ROW_LOOP |
|
":\n" |
|
|
|
|
|
"ld1 {v27.4s}, [%[bias_ptr]], #16\n" |
|
|
|
|
|
"dup v28.4s, v27.s[0]\n" |
|
"dup v29.4s, v27.s[1]\n" |
|
"dup v30.4s, v27.s[2]\n" |
|
"dup v31.4s, v27.s[3]\n" |
|
"dup v23.4s, v27.s[0]\n" |
|
"dup v24.4s, v27.s[1]\n" |
|
"dup v25.4s, v27.s[2]\n" |
|
"dup v26.4s, v27.s[3]\n" |
|
"dup v19.4s, v27.s[0]\n" |
|
"dup v20.4s, v27.s[1]\n" |
|
"dup v21.4s, v27.s[2]\n" |
|
"dup v22.4s, v27.s[3]\n" |
|
"dup v15.4s, v27.s[0]\n" |
|
"dup v16.4s, v27.s[1]\n" |
|
"dup v17.4s, v27.s[2]\n" |
|
"dup v18.4s, v27.s[3]\n" |
|
"dup v11.4s, v27.s[0]\n" |
|
"dup v12.4s, v27.s[1]\n" |
|
"dup v13.4s, v27.s[2]\n" |
|
"dup v14.4s, v27.s[3]\n" |
|
|
|
|
|
"ldr w6, [%[nnz_per_row]], #4\n" |
|
"cmp w6, #0\n" |
|
|
|
"beq " LABEL_SKIP_COL_LOOP "f\n" |
|
|
|
LABEL_COL_LOOP |
|
":\n" |
|
|
|
"ld1 {v0.4h}, [%[rhs_ptr]], x8\n" |
|
"mov v0.d[1], v0.d[0]\n" |
|
"ld1 {v1.4h}, [%[rhs2_ptr]], x8\n" |
|
"mov v1.d[1], v1.d[0]\n" |
|
"ld1 {v8.4h}, [%[rhs3_ptr]], x8\n" |
|
"mov v8.d[1], v8.d[0]\n" |
|
"ld1 {v9.4h}, [%[rhs4_ptr]], x8\n" |
|
"mov v9.d[1], v9.d[0]\n" |
|
"ld1 {v10.4h}, [%[rhs5_ptr]], x8\n" |
|
"mov v10.d[1], v10.d[0]\n" |
|
|
|
|
|
"ldrsh x8, [%[col_deltas_bytes]], #2\n" |
|
|
|
|
|
"ld1 {v2.8h, v3.8h}, [%[weights_ptr]], #32\n" |
|
|
|
|
|
"smlal v28.4s, v2.4h, v0.4h\n" |
|
"smlal2 v29.4s, v2.8h, v0.8h\n" |
|
"smlal v30.4s, v3.4h, v0.4h\n" |
|
"smlal2 v31.4s, v3.8h, v0.8h\n" |
|
"smlal v23.4s, v2.4h, v1.4h\n" |
|
"smlal2 v24.4s, v2.8h, v1.8h\n" |
|
"smlal v25.4s, v3.4h, v1.4h\n" |
|
"smlal2 v26.4s, v3.8h, v1.8h\n" |
|
"smlal v19.4s, v2.4h, v8.4h\n" |
|
"smlal2 v20.4s, v2.8h, v8.8h\n" |
|
"smlal v21.4s, v3.4h, v8.4h\n" |
|
"smlal2 v22.4s, v3.8h, v8.8h\n" |
|
"smlal v15.4s, v2.4h, v9.4h\n" |
|
"smlal2 v16.4s, v2.8h, v9.8h\n" |
|
"smlal v17.4s, v3.4h, v9.4h\n" |
|
"smlal2 v18.4s, v3.8h, v9.8h\n" |
|
"smlal v11.4s, v2.4h, v10.4h\n" |
|
"smlal2 v12.4s, v2.8h, v10.8h\n" |
|
"smlal v13.4s, v3.4h, v10.4h\n" |
|
"smlal2 v14.4s, v3.8h, v10.8h\n" |
|
|
|
|
|
"subs w6, w6, #1\n" |
|
"bne " LABEL_COL_LOOP "b\n" |
|
|
|
LABEL_SKIP_COL_LOOP |
|
":\n" |
|
|
|
"addp v28.4s, v28.4s, v29.4s\n" |
|
"addp v23.4s, v23.4s, v24.4s\n" |
|
"addp v19.4s, v19.4s, v20.4s\n" |
|
"addp v15.4s, v15.4s, v16.4s\n" |
|
"addp v11.4s, v11.4s, v12.4s\n" |
|
|
|
"addp v30.4s, v30.4s, v31.4s\n" |
|
"addp v25.4s, v25.4s, v26.4s\n" |
|
"addp v21.4s, v21.4s, v22.4s\n" |
|
"addp v17.4s, v17.4s, v18.4s\n" |
|
"addp v13.4s, v13.4s, v14.4s\n" |
|
|
|
"addp v28.4s, v28.4s, v30.4s\n" |
|
"addp v23.4s, v23.4s, v25.4s\n" |
|
"addp v19.4s, v19.4s, v21.4s\n" |
|
"addp v15.4s, v15.4s, v17.4s\n" |
|
"addp v11.4s, v11.4s, v13.4s\n" |
|
|
|
|
|
"st1 {v28.4s}, [%[out_ptr]], #16\n" |
|
"st1 {v23.4s}, [%[out2_ptr]], #16\n" |
|
"st1 {v19.4s}, [%[out3_ptr]], #16\n" |
|
"st1 {v15.4s}, [%[out4_ptr]], #16\n" |
|
"st1 {v11.4s}, [%[out5_ptr]], #16\n" |
|
|
|
|
|
"subs %[assigned_rows], %[assigned_rows], #1\n" |
|
"bne " LABEL_ROW_LOOP "b\n" |
|
|
|
|
|
: |
|
[out_ptr] "+r"(out_ptr), [out2_ptr] "+r"(out2_ptr), |
|
[out3_ptr] "+r"(out3_ptr), [out4_ptr] "+r"(out4_ptr), |
|
[out5_ptr] "+r"(out5_ptr), [weights_ptr] "+r"(weights_ptr), |
|
[col_deltas_bytes] "+r"(col_deltas_bytes), [bias_ptr] "+r"(bias_ptr), |
|
[nnz_per_row] "+r"(nnz_per_row), [assigned_rows] "+r"(assigned_rows), |
|
[rhs_ptr] "+r"(rhs_ptr), [rhs2_ptr] "+r"(rhs2_ptr), |
|
[rhs3_ptr] "+r"(rhs3_ptr), [rhs4_ptr] "+r"(rhs4_ptr), |
|
[rhs5_ptr] "+r"(rhs5_ptr) |
|
: |
|
: |
|
"cc", "memory", "x6", "x7", "x8", "v0", "v1", "v2", "v3", "v4", "v5", |
|
"v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", |
|
"v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", |
|
"v26", "v27", "v28", "v29", "v30", "v31"); |
|
|
|
} |
|
} |
|
|
|
|
|
|
|
template <typename WeightType, typename RhsType, typename OutType> |
|
typename std::enable_if<IsFixed16Type<WeightType>::value && |
|
IsFixed16Type<RhsType>::value && |
|
IsFixed16Type<OutType>::value>::type |
|
SpMV_4x4(const WeightType* weights_ptr, const int16_t* col_deltas_bytes, |
|
const int32_t* nnz_per_row, const RhsType* rhs_ptr, |
|
const typename TypeOfProduct<WeightType, RhsType>::type* bias_ptr, |
|
OutType* out_ptr, int64_t assigned_rows, |
|
int64_t rows , |
|
int64_t cols , int relu) { |
|
constexpr int kShiftAmount = 15 - WeightType::kExponentBits - |
|
RhsType::kExponentBits + OutType::kExponentBits; |
|
if (relu) { |
|
asm( |
|
|
|
"ldrsh x7, [%[col_deltas_bytes]], #2\n" |
|
"ldrsh x8, [%[col_deltas_bytes]], #2\n" |
|
|
|
"add %[rhs_ptr], %[rhs_ptr], x7\n" |
|
|
|
"movi v25.4s, #0\n" |
|
|
|
LABEL_ROW_LOOP |
|
":\n" |
|
|
|
|
|
"ld1 {v27.4s}, [%[bias_ptr]], #16\n" |
|
|
|
|
|
"dup v28.4s, v27.s[0]\n" |
|
"dup v29.4s, v27.s[1]\n" |
|
"dup v30.4s, v27.s[2]\n" |
|
"dup v31.4s, v27.s[3]\n" |
|
|
|
|
|
"ldr w6, [%[nnz_per_row]], #4\n" |
|
"cmp w6, #0\n" |
|
|
|
"beq " LABEL_SKIP_COL_LOOP "f\n" |
|
|
|
LABEL_COL_LOOP |
|
":\n" |
|
|
|
"ld1 {v0.4h}, [%[rhs_ptr]], x8\n" |
|
|
|
"mov v0.d[1], v0.d[0]\n" |
|
|
|
|
|
"ldrsh x8, [%[col_deltas_bytes]], #2\n" |
|
|
|
|
|
"ld1 {v2.8h, v3.8h}, [%[weights_ptr]], #32\n" |
|
|
|
|
|
"smlal v28.4s, v2.4h, v0.4h\n" |
|
"smlal2 v29.4s, v2.8h, v0.8h\n" |
|
"smlal v30.4s, v3.4h, v0.4h\n" |
|
"smlal2 v31.4s, v3.8h, v0.8h\n" |
|
|
|
|
|
"subs w6, w6, #1\n" |
|
"bne " LABEL_COL_LOOP "b\n" |
|
|
|
LABEL_SKIP_COL_LOOP |
|
":\n" |
|
|
|
|
|
"addp v28.4s, v28.4s, v29.4s\n" |
|
"addp v30.4s, v30.4s, v31.4s\n" |
|
"addp v28.4s, v28.4s, v30.4s\n" |
|
|
|
|
|
"smax v28.4s, v28.4s, v25.4s\n" |
|
"sqrshrn v26.4h, v28.4s, %[shift_amount]\n" |
|
|
|
|
|
"st1 {v26.4h}, [%[out_ptr]], #8\n" |
|
|
|
|
|
"subs %[assigned_rows], %[assigned_rows], #1\n" |
|
"bne " LABEL_ROW_LOOP "b\n" |
|
|
|
|
|
: |
|
[out_ptr] "+r"(out_ptr), [weights_ptr] "+r"(weights_ptr), |
|
[col_deltas_bytes] "+r"(col_deltas_bytes), [bias_ptr] "+r"(bias_ptr), |
|
[nnz_per_row] "+r"(nnz_per_row), [assigned_rows] "+r"(assigned_rows), |
|
[rhs_ptr] "+r"(rhs_ptr) |
|
: |
|
[shift_amount] "I"(kShiftAmount) |
|
: |
|
"cc", "memory", "x6", "x7", "x8", "v0", "v1", "v2", "v3", "v4", "v5", |
|
"v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", |
|
"v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", |
|
"v26", "v27", "v28", "v29", "v30", "v31"); |
|
|
|
} else { |
|
asm( |
|
|
|
"ldrsh x7, [%[col_deltas_bytes]], #2\n" |
|
"ldrsh x8, [%[col_deltas_bytes]], #2\n" |
|
|
|
"add %[rhs_ptr], %[rhs_ptr], x7\n" |
|
|
|
"movi v25.4s, #0\n" |
|
|
|
LABEL_ROW_LOOP |
|
":\n" |
|
|
|
|
|
"ld1 {v27.4s}, [%[bias_ptr]], #16\n" |
|
|
|
|
|
"dup v28.4s, v27.s[0]\n" |
|
"dup v29.4s, v27.s[1]\n" |
|
"dup v30.4s, v27.s[2]\n" |
|
"dup v31.4s, v27.s[3]\n" |
|
|
|
|
|
"ldr w6, [%[nnz_per_row]], #4\n" |
|
"cmp w6, #0\n" |
|
|
|
"beq " LABEL_SKIP_COL_LOOP "f\n" |
|
|
|
LABEL_COL_LOOP |
|
":\n" |
|
|
|
"ld1 {v0.4h}, [%[rhs_ptr]], x8\n" |
|
|
|
"mov v0.d[1], v0.d[0]\n" |
|
|
|
|
|
"ldrsh x8, [%[col_deltas_bytes]], #2\n" |
|
|
|
|
|
"ld1 {v2.8h, v3.8h}, [%[weights_ptr]], #32\n" |
|
|
|
|
|
"smlal v28.4s, v2.4h, v0.4h\n" |
|
"smlal2 v29.4s, v2.8h, v0.8h\n" |
|
"smlal v30.4s, v3.4h, v0.4h\n" |
|
"smlal2 v31.4s, v3.8h, v0.8h\n" |
|
|
|
|
|
"subs w6, w6, #1\n" |
|
"bne " LABEL_COL_LOOP "b\n" |
|
|
|
LABEL_SKIP_COL_LOOP |
|
":\n" |
|
|
|
|
|
"addp v28.4s, v28.4s, v29.4s\n" |
|
"addp v30.4s, v30.4s, v31.4s\n" |
|
"addp v28.4s, v28.4s, v30.4s\n" |
|
"sqrshrn v26.4h, v28.4s, %[shift_amount]\n" |
|
|
|
|
|
"st1 {v26.4h}, [%[out_ptr]], #8\n" |
|
|
|
|
|
"subs %[assigned_rows], %[assigned_rows], #1\n" |
|
"bne " LABEL_ROW_LOOP "b\n" |
|
|
|
|
|
: |
|
[out_ptr] "+r"(out_ptr), [weights_ptr] "+r"(weights_ptr), |
|
[col_deltas_bytes] "+r"(col_deltas_bytes), [bias_ptr] "+r"(bias_ptr), |
|
[nnz_per_row] "+r"(nnz_per_row), [assigned_rows] "+r"(assigned_rows), |
|
[rhs_ptr] "+r"(rhs_ptr) |
|
: |
|
[shift_amount] "I"(kShiftAmount) |
|
: |
|
"cc", "memory", "x6", "x7", "x8", "v0", "v1", "v2", "v3", "v4", "v5", |
|
"v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", |
|
"v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", |
|
"v26", "v27", "v28", "v29", "v30", "v31"); |
|
|
|
} |
|
} |
|
|
|
|
|
|
|
template <typename WeightType, typename RhsType, typename OutType> |
|
typename std::enable_if<IsFixed16Type<WeightType>::value && |
|
IsFixed16Type<RhsType>::value && |
|
IsFixed16Type<OutType>::value>::type |
|
SpMM5_4x4(const WeightType* weights_ptr, const int16_t* col_deltas_bytes, |
|
const int32_t* nnz_per_row, const RhsType* rhs_ptr, |
|
const typename TypeOfProduct<WeightType, RhsType>::type* bias_ptr, |
|
OutType* out_ptr, int64_t assigned_rows, int64_t rows, int64_t cols, |
|
int relu) { |
|
constexpr int kShiftAmount = 15 - WeightType::kExponentBits - |
|
RhsType::kExponentBits + OutType::kExponentBits; |
|
|
|
const RhsType* rhs2_ptr = rhs_ptr + cols; |
|
OutType* out2_ptr = out_ptr + rows; |
|
const RhsType* rhs3_ptr = rhs_ptr + 2 * cols; |
|
OutType* out3_ptr = out_ptr + 2 * rows; |
|
const RhsType* rhs4_ptr = rhs_ptr + 3 * cols; |
|
OutType* out4_ptr = out_ptr + 3 * rows; |
|
const RhsType* rhs5_ptr = rhs_ptr + 4 * cols; |
|
OutType* out5_ptr = out_ptr + 4 * rows; |
|
if (relu) { |
|
asm( |
|
|
|
"ldrsh x7, [%[col_deltas_bytes]], #2\n" |
|
"ldrsh x8, [%[col_deltas_bytes]], #2\n" |
|
|
|
"add %[rhs_ptr], %[rhs_ptr], x7\n" |
|
"add %[rhs2_ptr], %[rhs2_ptr], x7\n" |
|
"add %[rhs3_ptr], %[rhs3_ptr], x7\n" |
|
"add %[rhs4_ptr], %[rhs4_ptr], x7\n" |
|
"add %[rhs5_ptr], %[rhs5_ptr], x7\n" |
|
|
|
LABEL_ROW_LOOP |
|
":\n" |
|
|
|
|
|
"ld1 {v27.4s}, [%[bias_ptr]], #16\n" |
|
|
|
|
|
"dup v28.4s, v27.s[0]\n" |
|
"dup v29.4s, v27.s[1]\n" |
|
"dup v30.4s, v27.s[2]\n" |
|
"dup v31.4s, v27.s[3]\n" |
|
"dup v23.4s, v27.s[0]\n" |
|
"dup v24.4s, v27.s[1]\n" |
|
"dup v25.4s, v27.s[2]\n" |
|
"dup v26.4s, v27.s[3]\n" |
|
"dup v19.4s, v27.s[0]\n" |
|
"dup v20.4s, v27.s[1]\n" |
|
"dup v21.4s, v27.s[2]\n" |
|
"dup v22.4s, v27.s[3]\n" |
|
"dup v15.4s, v27.s[0]\n" |
|
"dup v16.4s, v27.s[1]\n" |
|
"dup v17.4s, v27.s[2]\n" |
|
"dup v18.4s, v27.s[3]\n" |
|
"dup v11.4s, v27.s[0]\n" |
|
"dup v12.4s, v27.s[1]\n" |
|
"dup v13.4s, v27.s[2]\n" |
|
"dup v14.4s, v27.s[3]\n" |
|
|
|
|
|
"ldr w6, [%[nnz_per_row]], #4\n" |
|
"cmp w6, #0\n" |
|
|
|
"beq " LABEL_SKIP_COL_LOOP "f\n" |
|
|
|
LABEL_COL_LOOP |
|
":\n" |
|
|
|
"ld1 {v0.4h}, [%[rhs_ptr]], x8\n" |
|
"mov v0.d[1], v0.d[0]\n" |
|
"ld1 {v1.4h}, [%[rhs2_ptr]], x8\n" |
|
"mov v1.d[1], v1.d[0]\n" |
|
"ld1 {v8.4h}, [%[rhs3_ptr]], x8\n" |
|
"mov v8.d[1], v8.d[0]\n" |
|
"ld1 {v9.4h}, [%[rhs4_ptr]], x8\n" |
|
"mov v9.d[1], v9.d[0]\n" |
|
"ld1 {v10.4h}, [%[rhs5_ptr]], x8\n" |
|
"mov v10.d[1], v10.d[0]\n" |
|
|
|
|
|
"ldrsh x8, [%[col_deltas_bytes]], #2\n" |
|
|
|
|
|
"ld1 {v2.8h, v3.8h}, [%[weights_ptr]], #32\n" |
|
|
|
|
|
"smlal v28.4s, v2.4h, v0.4h\n" |
|
"smlal2 v29.4s, v2.8h, v0.8h\n" |
|
"smlal v30.4s, v3.4h, v0.4h\n" |
|
"smlal2 v31.4s, v3.8h, v0.8h\n" |
|
"smlal v23.4s, v2.4h, v1.4h\n" |
|
"smlal2 v24.4s, v2.8h, v1.8h\n" |
|
"smlal v25.4s, v3.4h, v1.4h\n" |
|
"smlal2 v26.4s, v3.8h, v1.8h\n" |
|
"smlal v19.4s, v2.4h, v8.4h\n" |
|
"smlal2 v20.4s, v2.8h, v8.8h\n" |
|
"smlal v21.4s, v3.4h, v8.4h\n" |
|
"smlal2 v22.4s, v3.8h, v8.8h\n" |
|
"smlal v15.4s, v2.4h, v9.4h\n" |
|
"smlal2 v16.4s, v2.8h, v9.8h\n" |
|
"smlal v17.4s, v3.4h, v9.4h\n" |
|
"smlal2 v18.4s, v3.8h, v9.8h\n" |
|
"smlal v11.4s, v2.4h, v10.4h\n" |
|
"smlal2 v12.4s, v2.8h, v10.8h\n" |
|
"smlal v13.4s, v3.4h, v10.4h\n" |
|
"smlal2 v14.4s, v3.8h, v10.8h\n" |
|
|
|
|
|
"subs w6, w6, #1\n" |
|
"bne " LABEL_COL_LOOP "b\n" |
|
|
|
LABEL_SKIP_COL_LOOP |
|
":\n" |
|
|
|
"movi v0.4s, #0\n" |
|
"addp v28.4s, v28.4s, v29.4s\n" |
|
"addp v23.4s, v23.4s, v24.4s\n" |
|
"addp v19.4s, v19.4s, v20.4s\n" |
|
"addp v15.4s, v15.4s, v16.4s\n" |
|
"addp v11.4s, v11.4s, v12.4s\n" |
|
|
|
"addp v30.4s, v30.4s, v31.4s\n" |
|
"addp v25.4s, v25.4s, v26.4s\n" |
|
"addp v21.4s, v21.4s, v22.4s\n" |
|
"addp v17.4s, v17.4s, v18.4s\n" |
|
"addp v13.4s, v13.4s, v14.4s\n" |
|
|
|
"addp v28.4s, v28.4s, v30.4s\n" |
|
"addp v23.4s, v23.4s, v25.4s\n" |
|
"addp v19.4s, v19.4s, v21.4s\n" |
|
"addp v15.4s, v15.4s, v17.4s\n" |
|
"addp v11.4s, v11.4s, v13.4s\n" |
|
|
|
|
|
"smax v28.4s, v28.4s, v0.4s\n" |
|
"smax v23.4s, v23.4s, v0.4s\n" |
|
"smax v19.4s, v19.4s, v0.4s\n" |
|
"smax v15.4s, v15.4s, v0.4s\n" |
|
"smax v11.4s, v11.4s, v0.4s\n" |
|
"sqrshrn v26.4h, v28.4s, %[shift_amount]\n" |
|
"sqrshrn v22.4h, v23.4s, %[shift_amount]\n" |
|
"sqrshrn v18.4h, v19.4s, %[shift_amount]\n" |
|
"sqrshrn v14.4h, v15.4s, %[shift_amount]\n" |
|
"sqrshrn v10.4h, v11.4s, %[shift_amount]\n" |
|
|
|
|
|
"st1 {v26.4h}, [%[out_ptr]], #8\n" |
|
"st1 {v22.4h}, [%[out2_ptr]], #8\n" |
|
"st1 {v18.4h}, [%[out3_ptr]], #8\n" |
|
"st1 {v14.4h}, [%[out4_ptr]], #8\n" |
|
"st1 {v10.4h}, [%[out5_ptr]], #8\n" |
|
|
|
|
|
"subs %[assigned_rows], %[assigned_rows], #1\n" |
|
"bne " LABEL_ROW_LOOP "b\n" |
|
|
|
|
|
: |
|
[out_ptr] "+r"(out_ptr), [out2_ptr] "+r"(out2_ptr), |
|
[out3_ptr] "+r"(out3_ptr), [out4_ptr] "+r"(out4_ptr), |
|
[out5_ptr] "+r"(out5_ptr), [weights_ptr] "+r"(weights_ptr), |
|
[col_deltas_bytes] "+r"(col_deltas_bytes), [bias_ptr] "+r"(bias_ptr), |
|
[nnz_per_row] "+r"(nnz_per_row), [assigned_rows] "+r"(assigned_rows), |
|
[rhs_ptr] "+r"(rhs_ptr), [rhs2_ptr] "+r"(rhs2_ptr), |
|
[rhs3_ptr] "+r"(rhs3_ptr), [rhs4_ptr] "+r"(rhs4_ptr), |
|
[rhs5_ptr] "+r"(rhs5_ptr) |
|
: |
|
[shift_amount] "I"(kShiftAmount) |
|
: |
|
"cc", "memory", "x6", "x7", "x8", "v0", "v1", "v2", "v3", "v4", "v5", |
|
"v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", |
|
"v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", |
|
"v26", "v27", "v28", "v29", "v30", "v31"); |
|
|
|
} else { |
|
asm( |
|
|
|
"ldrsh x7, [%[col_deltas_bytes]], #2\n" |
|
"ldrsh x8, [%[col_deltas_bytes]], #2\n" |
|
|
|
"add %[rhs_ptr], %[rhs_ptr], x7\n" |
|
"add %[rhs2_ptr], %[rhs2_ptr], x7\n" |
|
"add %[rhs3_ptr], %[rhs3_ptr], x7\n" |
|
"add %[rhs4_ptr], %[rhs4_ptr], x7\n" |
|
"add %[rhs5_ptr], %[rhs5_ptr], x7\n" |
|
|
|
LABEL_ROW_LOOP |
|
":\n" |
|
|
|
|
|
"ld1 {v27.4s}, [%[bias_ptr]], #16\n" |
|
|
|
|
|
"dup v28.4s, v27.s[0]\n" |
|
"dup v29.4s, v27.s[1]\n" |
|
"dup v30.4s, v27.s[2]\n" |
|
"dup v31.4s, v27.s[3]\n" |
|
"dup v23.4s, v27.s[0]\n" |
|
"dup v24.4s, v27.s[1]\n" |
|
"dup v25.4s, v27.s[2]\n" |
|
"dup v26.4s, v27.s[3]\n" |
|
"dup v19.4s, v27.s[0]\n" |
|
"dup v20.4s, v27.s[1]\n" |
|
"dup v21.4s, v27.s[2]\n" |
|
"dup v22.4s, v27.s[3]\n" |
|
"dup v15.4s, v27.s[0]\n" |
|
"dup v16.4s, v27.s[1]\n" |
|
"dup v17.4s, v27.s[2]\n" |
|
"dup v18.4s, v27.s[3]\n" |
|
"dup v11.4s, v27.s[0]\n" |
|
"dup v12.4s, v27.s[1]\n" |
|
"dup v13.4s, v27.s[2]\n" |
|
"dup v14.4s, v27.s[3]\n" |
|
|
|
|
|
"ldr w6, [%[nnz_per_row]], #4\n" |
|
"cmp w6, #0\n" |
|
|
|
"beq " LABEL_SKIP_COL_LOOP "f\n" |
|
|
|
LABEL_COL_LOOP |
|
":\n" |
|
|
|
"ld1 {v0.4h}, [%[rhs_ptr]], x8\n" |
|
"mov v0.d[1], v0.d[0]\n" |
|
"ld1 {v1.4h}, [%[rhs2_ptr]], x8\n" |
|
"mov v1.d[1], v1.d[0]\n" |
|
"ld1 {v8.4h}, [%[rhs3_ptr]], x8\n" |
|
"mov v8.d[1], v8.d[0]\n" |
|
"ld1 {v9.4h}, [%[rhs4_ptr]], x8\n" |
|
"mov v9.d[1], v9.d[0]\n" |
|
"ld1 {v10.4h}, [%[rhs5_ptr]], x8\n" |
|
"mov v10.d[1], v10.d[0]\n" |
|
|
|
|
|
"ldrsh x8, [%[col_deltas_bytes]], #2\n" |
|
|
|
|
|
"ld1 {v2.8h, v3.8h}, [%[weights_ptr]], #32\n" |
|
|
|
|
|
"smlal v28.4s, v2.4h, v0.4h\n" |
|
"smlal2 v29.4s, v2.8h, v0.8h\n" |
|
"smlal v30.4s, v3.4h, v0.4h\n" |
|
"smlal2 v31.4s, v3.8h, v0.8h\n" |
|
"smlal v23.4s, v2.4h, v1.4h\n" |
|
"smlal2 v24.4s, v2.8h, v1.8h\n" |
|
"smlal v25.4s, v3.4h, v1.4h\n" |
|
"smlal2 v26.4s, v3.8h, v1.8h\n" |
|
"smlal v19.4s, v2.4h, v8.4h\n" |
|
"smlal2 v20.4s, v2.8h, v8.8h\n" |
|
"smlal v21.4s, v3.4h, v8.4h\n" |
|
"smlal2 v22.4s, v3.8h, v8.8h\n" |
|
"smlal v15.4s, v2.4h, v9.4h\n" |
|
"smlal2 v16.4s, v2.8h, v9.8h\n" |
|
"smlal v17.4s, v3.4h, v9.4h\n" |
|
"smlal2 v18.4s, v3.8h, v9.8h\n" |
|
"smlal v11.4s, v2.4h, v10.4h\n" |
|
"smlal2 v12.4s, v2.8h, v10.8h\n" |
|
"smlal v13.4s, v3.4h, v10.4h\n" |
|
"smlal2 v14.4s, v3.8h, v10.8h\n" |
|
|
|
|
|
"subs w6, w6, #1\n" |
|
"bne " LABEL_COL_LOOP "b\n" |
|
|
|
LABEL_SKIP_COL_LOOP |
|
":\n" |
|
|
|
"addp v28.4s, v28.4s, v29.4s\n" |
|
"addp v23.4s, v23.4s, v24.4s\n" |
|
"addp v19.4s, v19.4s, v20.4s\n" |
|
"addp v15.4s, v15.4s, v16.4s\n" |
|
"addp v11.4s, v11.4s, v12.4s\n" |
|
|
|
"addp v30.4s, v30.4s, v31.4s\n" |
|
"addp v25.4s, v25.4s, v26.4s\n" |
|
"addp v21.4s, v21.4s, v22.4s\n" |
|
"addp v17.4s, v17.4s, v18.4s\n" |
|
"addp v13.4s, v13.4s, v14.4s\n" |
|
|
|
"addp v28.4s, v28.4s, v30.4s\n" |
|
"addp v23.4s, v23.4s, v25.4s\n" |
|
"addp v19.4s, v19.4s, v21.4s\n" |
|
"addp v15.4s, v15.4s, v17.4s\n" |
|
"addp v11.4s, v11.4s, v13.4s\n" |
|
|
|
"sqrshrn v26.4h, v28.4s, %[shift_amount]\n" |
|
"sqrshrn v22.4h, v23.4s, %[shift_amount]\n" |
|
"sqrshrn v18.4h, v19.4s, %[shift_amount]\n" |
|
"sqrshrn v14.4h, v15.4s, %[shift_amount]\n" |
|
"sqrshrn v10.4h, v11.4s, %[shift_amount]\n" |
|
|
|
|
|
"st1 {v26.4h}, [%[out_ptr]], #8\n" |
|
"st1 {v22.4h}, [%[out2_ptr]], #8\n" |
|
"st1 {v18.4h}, [%[out3_ptr]], #8\n" |
|
"st1 {v14.4h}, [%[out4_ptr]], #8\n" |
|
"st1 {v10.4h}, [%[out5_ptr]], #8\n" |
|
|
|
|
|
"subs %[assigned_rows], %[assigned_rows], #1\n" |
|
"bne " LABEL_ROW_LOOP "b\n" |
|
|
|
|
|
: |
|
[out_ptr] "+r"(out_ptr), [out2_ptr] "+r"(out2_ptr), |
|
[out3_ptr] "+r"(out3_ptr), [out4_ptr] "+r"(out4_ptr), |
|
[out5_ptr] "+r"(out5_ptr), [weights_ptr] "+r"(weights_ptr), |
|
[col_deltas_bytes] "+r"(col_deltas_bytes), [bias_ptr] "+r"(bias_ptr), |
|
[nnz_per_row] "+r"(nnz_per_row), [assigned_rows] "+r"(assigned_rows), |
|
[rhs_ptr] "+r"(rhs_ptr), [rhs2_ptr] "+r"(rhs2_ptr), |
|
[rhs3_ptr] "+r"(rhs3_ptr), [rhs4_ptr] "+r"(rhs4_ptr), |
|
[rhs5_ptr] "+r"(rhs5_ptr) |
|
: |
|
[shift_amount] "I"(kShiftAmount) |
|
: |
|
"cc", "memory", "x6", "x7", "x8", "v0", "v1", "v2", "v3", "v4", "v5", |
|
"v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", |
|
"v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", |
|
"v26", "v27", "v28", "v29", "v30", "v31"); |
|
|
|
} |
|
} |
|
|
|
|
|
|
|
template <typename WeightType, typename RhsType, typename OutType> |
|
typename std::enable_if< |
|
IsFixed16Type<WeightType>::value && IsFixed16Type<RhsType>::value && |
|
IsFixed32Type<OutType>::value && |
|
!std::is_same<OutType, typename TypeOfProduct<WeightType, |
|
RhsType>::type>::value>::type |
|
SpMV_4x4(const WeightType* weights_ptr, const int16_t* col_deltas_bytes, |
|
const int32_t* nnz_per_row, const RhsType* rhs_ptr, |
|
const typename TypeOfProduct<WeightType, RhsType>::type* bias_ptr, |
|
OutType* out_ptr, int64_t assigned_rows, |
|
int64_t rows , |
|
int64_t cols , int relu) { |
|
constexpr int kShiftAmount = |
|
TypeOfProduct<WeightType, RhsType>::type::kMantissaBits - |
|
OutType::kMantissaBits; |
|
static_assert(kShiftAmount > 0, |
|
"Result must have fewer mantissa bits than product"); |
|
if (relu) { |
|
asm( |
|
|
|
"ldrsh x7, [%[col_deltas_bytes]], #2\n" |
|
"ldrsh x8, [%[col_deltas_bytes]], #2\n" |
|
|
|
"add %[rhs_ptr], %[rhs_ptr], x7\n" |
|
|
|
"movi v25.4s, #0\n" |
|
|
|
LABEL_ROW_LOOP |
|
":\n" |
|
|
|
|
|
"ld1 {v27.4s}, [%[bias_ptr]], #16\n" |
|
|
|
|
|
"dup v28.4s, v27.s[0]\n" |
|
"dup v29.4s, v27.s[1]\n" |
|
"dup v30.4s, v27.s[2]\n" |
|
"dup v31.4s, v27.s[3]\n" |
|
|
|
|
|
"ldr w6, [%[nnz_per_row]], #4\n" |
|
"cmp w6, #0\n" |
|
|
|
"beq " LABEL_SKIP_COL_LOOP "f\n" |
|
|
|
LABEL_COL_LOOP |
|
":\n" |
|
|
|
"ld1 {v0.4h}, [%[rhs_ptr]], x8\n" |
|
|
|
"mov v0.d[1], v0.d[0]\n" |
|
|
|
|
|
"ldrsh x8, [%[col_deltas_bytes]], #2\n" |
|
|
|
|
|
"ld1 {v2.8h, v3.8h}, [%[weights_ptr]], #32\n" |
|
|
|
|
|
"smlal v28.4s, v2.4h, v0.4h\n" |
|
"smlal2 v29.4s, v2.8h, v0.8h\n" |
|
"smlal v30.4s, v3.4h, v0.4h\n" |
|
"smlal2 v31.4s, v3.8h, v0.8h\n" |
|
|
|
|
|
"subs w6, w6, #1\n" |
|
"bne " LABEL_COL_LOOP "b\n" |
|
|
|
LABEL_SKIP_COL_LOOP |
|
":\n" |
|
|
|
|
|
"addp v28.4s, v28.4s, v29.4s\n" |
|
"addp v30.4s, v30.4s, v31.4s\n" |
|
"addp v28.4s, v28.4s, v30.4s\n" |
|
|
|
|
|
"smax v28.4s, v28.4s, v25.4s\n" |
|
"srshr v28.4s, v28.4s, %[shift_amount]\n" |
|
|
|
|
|
"st1 {v28.4s}, [%[out_ptr]], #16\n" |
|
|
|
|
|
"subs %[assigned_rows], %[assigned_rows], #1\n" |
|
"bne " LABEL_ROW_LOOP "b\n" |
|
|
|
|
|
: |
|
[out_ptr] "+r"(out_ptr), [weights_ptr] "+r"(weights_ptr), |
|
[col_deltas_bytes] "+r"(col_deltas_bytes), [bias_ptr] "+r"(bias_ptr), |
|
[nnz_per_row] "+r"(nnz_per_row), [assigned_rows] "+r"(assigned_rows), |
|
[rhs_ptr] "+r"(rhs_ptr) |
|
: |
|
[shift_amount] "I"(kShiftAmount) |
|
: |
|
"cc", "memory", "x6", "x7", "x8", "v0", "v1", "v2", "v3", "v4", "v5", |
|
"v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", |
|
"v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", |
|
"v26", "v27", "v28", "v29", "v30", "v31"); |
|
|
|
} else { |
|
asm( |
|
|
|
"ldrsh x7, [%[col_deltas_bytes]], #2\n" |
|
"ldrsh x8, [%[col_deltas_bytes]], #2\n" |
|
|
|
"add %[rhs_ptr], %[rhs_ptr], x7\n" |
|
|
|
"movi v25.4s, #0\n" |
|
|
|
LABEL_ROW_LOOP |
|
":\n" |
|
|
|
|
|
"ld1 {v27.4s}, [%[bias_ptr]], #16\n" |
|
|
|
|
|
"dup v28.4s, v27.s[0]\n" |
|
"dup v29.4s, v27.s[1]\n" |
|
"dup v30.4s, v27.s[2]\n" |
|
"dup v31.4s, v27.s[3]\n" |
|
|
|
|
|
"ldr w6, [%[nnz_per_row]], #4\n" |
|
"cmp w6, #0\n" |
|
|
|
"beq " LABEL_SKIP_COL_LOOP "f\n" |
|
|
|
LABEL_COL_LOOP |
|
":\n" |
|
|
|
"ld1 {v0.4h}, [%[rhs_ptr]], x8\n" |
|
|
|
"mov v0.d[1], v0.d[0]\n" |
|
|
|
|
|
"ldrsh x8, [%[col_deltas_bytes]], #2\n" |
|
|
|
|
|
"ld1 {v2.8h, v3.8h}, [%[weights_ptr]], #32\n" |
|
|
|
|
|
"smlal v28.4s, v2.4h, v0.4h\n" |
|
"smlal2 v29.4s, v2.8h, v0.8h\n" |
|
"smlal v30.4s, v3.4h, v0.4h\n" |
|
"smlal2 v31.4s, v3.8h, v0.8h\n" |
|
|
|
|
|
"subs w6, w6, #1\n" |
|
"bne " LABEL_COL_LOOP "b\n" |
|
|
|
LABEL_SKIP_COL_LOOP |
|
":\n" |
|
|
|
|
|
"addp v28.4s, v28.4s, v29.4s\n" |
|
"addp v30.4s, v30.4s, v31.4s\n" |
|
"addp v28.4s, v28.4s, v30.4s\n" |
|
|
|
"srshr v28.4s, v28.4s, %[shift_amount]\n" |
|
|
|
"st1 {v28.4s}, [%[out_ptr]], #16\n" |
|
|
|
|
|
"subs %[assigned_rows], %[assigned_rows], #1\n" |
|
"bne " LABEL_ROW_LOOP "b\n" |
|
|
|
|
|
: |
|
[out_ptr] "+r"(out_ptr), [weights_ptr] "+r"(weights_ptr), |
|
[col_deltas_bytes] "+r"(col_deltas_bytes), [bias_ptr] "+r"(bias_ptr), |
|
[nnz_per_row] "+r"(nnz_per_row), [assigned_rows] "+r"(assigned_rows), |
|
[rhs_ptr] "+r"(rhs_ptr) |
|
: |
|
[shift_amount] "I"(kShiftAmount) |
|
: |
|
"cc", "memory", "x6", "x7", "x8", "v0", "v1", "v2", "v3", "v4", "v5", |
|
"v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", |
|
"v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", |
|
"v26", "v27", "v28", "v29", "v30", "v31"); |
|
|
|
} |
|
} |
|
|
|
|
|
|
|
// Fixed-point sparse kernel that processes five right-hand-side vectors in one
// pass over the weights. It is enabled when WeightType and RhsType are 16-bit
// fixed types and OutType is a 32-bit fixed type that differs from the product
// type, so every accumulated sum is shifted right by kShiftAmount (rounding
// shift, srshr) before it is stored.
//
// Per row-loop iteration the assembly reads 4 x 32-bit values from bias_ptr,
// one 32-bit block count from nnz_per_row, and for each block 16 x 16-bit
// weights plus 4 x 16-bit rhs values per rhs vector; it then writes 4 x 32-bit
// results to each of the five output pointers.
//
//   weights_ptr       16-bit weights, consumed 32 bytes (one 4x4 block) at a
//                     time.
//   col_deltas_bytes  signed 16-bit offsets, added (as bytes) to all five rhs
//                     pointers to move from one block's rhs values to the next.
//   nnz_per_row       32-bit number of blocks for each row-loop iteration.
//   rhs_ptr           first rhs vector; vector k is at rhs_ptr + k * cols.
//   bias_ptr          32-bit biases, 4 per row-loop iteration.
//   out_ptr           first output vector; vector k is at out_ptr + k * rows.
//   assigned_rows     number of row-loop iterations to run.
//   relu              if nonzero, results are clamped to >= 0 before the shift.
//
// NOTE(review): the row loop tests its counter at the bottom, so the body runs
// at least once; callers presumably guarantee assigned_rows >= 1.
// NOTE(review): each bias element is broadcast into all four lanes of an
// accumulator and the lanes are summed at the end, so the stored value contains
// 4x the loaded bias; presumably bias_ptr is pre-scaled by 1/4 -- confirm
// against the caller.
// NOTE(review): deltas are fetched one block ahead (two loads before the loop,
// one per block inside it), so col_deltas_bytes is read past the last delta
// that is actually used -- confirm the array is padded accordingly.
template <typename WeightType, typename RhsType, typename OutType> |
|
typename std::enable_if< |
|
IsFixed16Type<WeightType>::value && IsFixed16Type<RhsType>::value && |
|
IsFixed32Type<OutType>::value && |
|
!std::is_same<OutType, typename TypeOfProduct<WeightType, |
|
RhsType>::type>::value>::type |
|
SpMM5_4x4(const WeightType* weights_ptr, const int16_t* col_deltas_bytes, |
|
const int32_t* nnz_per_row, const RhsType* rhs_ptr, |
|
const typename TypeOfProduct<WeightType, RhsType>::type* bias_ptr, |
|
OutType* out_ptr, int64_t assigned_rows, int64_t rows, int64_t cols, |
|
int relu) { |
|
// Number of mantissa bits to drop when converting the product-typed
// accumulator to OutType; must be an immediate for srshr.
constexpr int kShiftAmount = |
|
TypeOfProduct<WeightType, RhsType>::type::kMantissaBits - |
|
OutType::kMantissaBits; |
|
static_assert(kShiftAmount > 0, |
|
"Result must have fewer mantissa bits than product"); |
|
|
|
// Pointers for rhs/output vectors 2..5; vector k starts k * cols (rhs) or
// k * rows (output) elements after the first one.
const RhsType* rhs2_ptr = rhs_ptr + cols; |
|
OutType* out2_ptr = out_ptr + rows; |
|
const RhsType* rhs3_ptr = rhs_ptr + 2 * cols; |
|
OutType* out3_ptr = out_ptr + 2 * rows; |
|
const RhsType* rhs4_ptr = rhs_ptr + 3 * cols; |
|
OutType* out4_ptr = out_ptr + 3 * rows; |
|
const RhsType* rhs5_ptr = rhs_ptr + 4 * cols; |
|
OutType* out5_ptr = out_ptr + 4 * rows; |
|
// Two copies of the kernel: this one clamps at zero (smax) before the shift.
if (relu) { |
|
asm( |
|
|
|
// x7 = first delta (initial byte offset for the rhs pointers),
// x8 = second delta (post-increment used by the first block's loads).
"ldrsh x7, [%[col_deltas_bytes]], #2\n" |
|
"ldrsh x8, [%[col_deltas_bytes]], #2\n" |
|
|
|
// Move all five rhs pointers to the first block's column.
"add %[rhs_ptr], %[rhs_ptr], x7\n" |
|
"add %[rhs2_ptr], %[rhs2_ptr], x7\n" |
|
"add %[rhs3_ptr], %[rhs3_ptr], x7\n" |
|
"add %[rhs4_ptr], %[rhs4_ptr], x7\n" |
|
"add %[rhs5_ptr], %[rhs5_ptr], x7\n" |
|
|
|
// Start of the per-row loop (one iteration = 4 outputs per rhs vector).
LABEL_ROW_LOOP |
|
":\n" |
|
|
|
|
|
// Load 4 x 32-bit biases into v27 and advance bias_ptr by 16 bytes.
"ld1 {v27.4s}, [%[bias_ptr]], #16\n" |
|
|
|
|
|
// Initialise the accumulators: bias lane j is broadcast into the j-th
// accumulator of each group. Groups: v28-v31 (rhs 1), v23-v26 (rhs 2),
// v19-v22 (rhs 3), v15-v18 (rhs 4), v11-v14 (rhs 5).
"dup v28.4s, v27.s[0]\n" |
|
"dup v29.4s, v27.s[1]\n" |
|
"dup v30.4s, v27.s[2]\n" |
|
"dup v31.4s, v27.s[3]\n" |
|
"dup v23.4s, v27.s[0]\n" |
|
"dup v24.4s, v27.s[1]\n" |
|
"dup v25.4s, v27.s[2]\n" |
|
"dup v26.4s, v27.s[3]\n" |
|
"dup v19.4s, v27.s[0]\n" |
|
"dup v20.4s, v27.s[1]\n" |
|
"dup v21.4s, v27.s[2]\n" |
|
"dup v22.4s, v27.s[3]\n" |
|
"dup v15.4s, v27.s[0]\n" |
|
"dup v16.4s, v27.s[1]\n" |
|
"dup v17.4s, v27.s[2]\n" |
|
"dup v18.4s, v27.s[3]\n" |
|
"dup v11.4s, v27.s[0]\n" |
|
"dup v12.4s, v27.s[1]\n" |
|
"dup v13.4s, v27.s[2]\n" |
|
"dup v14.4s, v27.s[3]\n" |
|
|
|
|
|
// w6 = number of blocks in this row; advance nnz_per_row by 4 bytes.
"ldr w6, [%[nnz_per_row]], #4\n" |
|
"cmp w6, #0\n" |
|
|
|
// No blocks in this row: go straight to the reduction (bias only).
"beq " LABEL_SKIP_COL_LOOP "f\n" |
|
|
|
// Start of the per-block loop.
LABEL_COL_LOOP |
|
":\n" |
|
|
|
// For each rhs vector: load 4 x 16-bit values, post-increment the
// pointer by x8 bytes, and copy the low 64 bits into the high half so
// the smlal2 instructions (which read the upper half) see the same four
// values.
"ld1 {v0.4h}, [%[rhs_ptr]], x8\n" |
|
"mov v0.d[1], v0.d[0]\n" |
|
"ld1 {v1.4h}, [%[rhs2_ptr]], x8\n" |
|
"mov v1.d[1], v1.d[0]\n" |
|
"ld1 {v8.4h}, [%[rhs3_ptr]], x8\n" |
|
"mov v8.d[1], v8.d[0]\n" |
|
"ld1 {v9.4h}, [%[rhs4_ptr]], x8\n" |
|
"mov v9.d[1], v9.d[0]\n" |
|
"ld1 {v10.4h}, [%[rhs5_ptr]], x8\n" |
|
"mov v10.d[1], v10.d[0]\n" |
|
|
|
|
|
// Fetch the delta that the next iteration's loads will post-increment by.
"ldrsh x8, [%[col_deltas_bytes]], #2\n" |
|
|
|
|
|
// Load one block of 16 x 16-bit weights into v2/v3; advance 32 bytes.
"ld1 {v2.8h, v3.8h}, [%[weights_ptr]], #32\n" |
|
|
|
|
|
// Widening multiply-accumulate (16-bit x 16-bit -> 32-bit). Weights
// 0-3, 4-7, 8-11 and 12-15 of the block feed the 1st..4th accumulator of
// each group, each multiplied lane-wise by the four rhs values.
"smlal v28.4s, v2.4h, v0.4h\n" |
|
"smlal2 v29.4s, v2.8h, v0.8h\n" |
|
"smlal v30.4s, v3.4h, v0.4h\n" |
|
"smlal2 v31.4s, v3.8h, v0.8h\n" |
|
"smlal v23.4s, v2.4h, v1.4h\n" |
|
"smlal2 v24.4s, v2.8h, v1.8h\n" |
|
"smlal v25.4s, v3.4h, v1.4h\n" |
|
"smlal2 v26.4s, v3.8h, v1.8h\n" |
|
"smlal v19.4s, v2.4h, v8.4h\n" |
|
"smlal2 v20.4s, v2.8h, v8.8h\n" |
|
"smlal v21.4s, v3.4h, v8.4h\n" |
|
"smlal2 v22.4s, v3.8h, v8.8h\n" |
|
"smlal v15.4s, v2.4h, v9.4h\n" |
|
"smlal2 v16.4s, v2.8h, v9.8h\n" |
|
"smlal v17.4s, v3.4h, v9.4h\n" |
|
"smlal2 v18.4s, v3.8h, v9.8h\n" |
|
"smlal v11.4s, v2.4h, v10.4h\n" |
|
"smlal2 v12.4s, v2.8h, v10.8h\n" |
|
"smlal v13.4s, v3.4h, v10.4h\n" |
|
"smlal2 v14.4s, v3.8h, v10.8h\n" |
|
|
|
|
|
// Next block of this row, if any remain.
"subs w6, w6, #1\n" |
|
"bne " LABEL_COL_LOOP "b\n" |
|
|
|
LABEL_SKIP_COL_LOOP |
|
":\n" |
|
|
|
// v0 = 0, used as the lower bound for the relu clamp below.
"movi v0.4s, #0\n" |
|
// Horizontal reduction by three rounds of pairwise adds: afterwards lane
// j of v28/v23/v19/v15/v11 holds the sum of all four lanes of the j-th
// accumulator of that group.
"addp v28.4s, v28.4s, v29.4s\n" |
|
"addp v23.4s, v23.4s, v24.4s\n" |
|
"addp v19.4s, v19.4s, v20.4s\n" |
|
"addp v15.4s, v15.4s, v16.4s\n" |
|
"addp v11.4s, v11.4s, v12.4s\n" |
|
|
|
"addp v30.4s, v30.4s, v31.4s\n" |
|
"addp v25.4s, v25.4s, v26.4s\n" |
|
"addp v21.4s, v21.4s, v22.4s\n" |
|
"addp v17.4s, v17.4s, v18.4s\n" |
|
"addp v13.4s, v13.4s, v14.4s\n" |
|
|
|
"addp v28.4s, v28.4s, v30.4s\n" |
|
"addp v23.4s, v23.4s, v25.4s\n" |
|
"addp v19.4s, v19.4s, v21.4s\n" |
|
"addp v15.4s, v15.4s, v17.4s\n" |
|
"addp v11.4s, v11.4s, v13.4s\n" |
|
|
|
|
|
// Relu: signed max with zero.
"smax v28.4s, v28.4s, v0.4s\n" |
|
"smax v23.4s, v23.4s, v0.4s\n" |
|
"smax v19.4s, v19.4s, v0.4s\n" |
|
"smax v15.4s, v15.4s, v0.4s\n" |
|
"smax v11.4s, v11.4s, v0.4s\n" |
|
|
|
// Signed rounding shift right by kShiftAmount to reach OutType's
// mantissa width.
"srshr v28.4s, v28.4s, %[shift_amount]\n" |
|
"srshr v23.4s, v23.4s, %[shift_amount]\n" |
|
"srshr v19.4s, v19.4s, %[shift_amount]\n" |
|
"srshr v15.4s, v15.4s, %[shift_amount]\n" |
|
"srshr v11.4s, v11.4s, %[shift_amount]\n" |
|
|
|
|
|
// Store 4 x 32-bit results per rhs vector; advance each output by 16
// bytes.
"st1 {v28.4s}, [%[out_ptr]], #16\n" |
|
"st1 {v23.4s}, [%[out2_ptr]], #16\n" |
|
"st1 {v19.4s}, [%[out3_ptr]], #16\n" |
|
"st1 {v15.4s}, [%[out4_ptr]], #16\n" |
|
"st1 {v11.4s}, [%[out5_ptr]], #16\n" |
|
|
|
|
|
// Next row; the counter is decremented and tested only here.
"subs %[assigned_rows], %[assigned_rows], #1\n" |
|
"bne " LABEL_ROW_LOOP "b\n" |
|
|
|
|
|
// Outputs: every pointer and the row counter are modified in place ("+r").
: |
|
[out_ptr] "+r"(out_ptr), [out2_ptr] "+r"(out2_ptr), |
|
[out3_ptr] "+r"(out3_ptr), [out4_ptr] "+r"(out4_ptr), |
|
[out5_ptr] "+r"(out5_ptr), [weights_ptr] "+r"(weights_ptr), |
|
[col_deltas_bytes] "+r"(col_deltas_bytes), [bias_ptr] "+r"(bias_ptr), |
|
[nnz_per_row] "+r"(nnz_per_row), [assigned_rows] "+r"(assigned_rows), |
|
[rhs_ptr] "+r"(rhs_ptr), [rhs2_ptr] "+r"(rhs2_ptr), |
|
[rhs3_ptr] "+r"(rhs3_ptr), [rhs4_ptr] "+r"(rhs4_ptr), |
|
[rhs5_ptr] "+r"(rhs5_ptr) |
|
// Inputs: the shift count is passed as an immediate ("I" constraint).
: |
|
[shift_amount] "I"(kShiftAmount) |
|
// Clobbers: flags, memory, scratch registers x6-x8 and all vector registers.
: |
|
"cc", "memory", "x6", "x7", "x8", "v0", "v1", "v2", "v3", "v4", "v5", |
|
"v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", |
|
"v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", |
|
"v26", "v27", "v28", "v29", "v30", "v31"); |
|
|
|
// Same kernel without the relu clamp (no movi/smax before the shift).
} else { |
|
asm( |
|
|
|
// x7 = first delta (initial byte offset for the rhs pointers),
// x8 = second delta (post-increment used by the first block's loads).
"ldrsh x7, [%[col_deltas_bytes]], #2\n" |
|
"ldrsh x8, [%[col_deltas_bytes]], #2\n" |
|
|
|
// Move all five rhs pointers to the first block's column.
"add %[rhs_ptr], %[rhs_ptr], x7\n" |
|
"add %[rhs2_ptr], %[rhs2_ptr], x7\n" |
|
"add %[rhs3_ptr], %[rhs3_ptr], x7\n" |
|
"add %[rhs4_ptr], %[rhs4_ptr], x7\n" |
|
"add %[rhs5_ptr], %[rhs5_ptr], x7\n" |
|
|
|
// Start of the per-row loop.
LABEL_ROW_LOOP |
|
":\n" |
|
|
|
|
|
// Load 4 x 32-bit biases into v27 and advance bias_ptr by 16 bytes.
"ld1 {v27.4s}, [%[bias_ptr]], #16\n" |
|
|
|
|
|
// Broadcast bias lane j into the j-th accumulator of each of the five
// groups (v28-v31, v23-v26, v19-v22, v15-v18, v11-v14).
"dup v28.4s, v27.s[0]\n" |
|
"dup v29.4s, v27.s[1]\n" |
|
"dup v30.4s, v27.s[2]\n" |
|
"dup v31.4s, v27.s[3]\n" |
|
"dup v23.4s, v27.s[0]\n" |
|
"dup v24.4s, v27.s[1]\n" |
|
"dup v25.4s, v27.s[2]\n" |
|
"dup v26.4s, v27.s[3]\n" |
|
"dup v19.4s, v27.s[0]\n" |
|
"dup v20.4s, v27.s[1]\n" |
|
"dup v21.4s, v27.s[2]\n" |
|
"dup v22.4s, v27.s[3]\n" |
|
"dup v15.4s, v27.s[0]\n" |
|
"dup v16.4s, v27.s[1]\n" |
|
"dup v17.4s, v27.s[2]\n" |
|
"dup v18.4s, v27.s[3]\n" |
|
"dup v11.4s, v27.s[0]\n" |
|
"dup v12.4s, v27.s[1]\n" |
|
"dup v13.4s, v27.s[2]\n" |
|
"dup v14.4s, v27.s[3]\n" |
|
|
|
|
|
// w6 = number of blocks in this row; advance nnz_per_row by 4 bytes.
"ldr w6, [%[nnz_per_row]], #4\n" |
|
"cmp w6, #0\n" |
|
|
|
// No blocks in this row: skip the block loop.
"beq " LABEL_SKIP_COL_LOOP "f\n" |
|
|
|
// Start of the per-block loop.
LABEL_COL_LOOP |
|
":\n" |
|
|
|
// Load 4 x 16-bit rhs values per vector (post-increment by x8 bytes) and
// mirror them into the upper half for the smlal2 instructions.
"ld1 {v0.4h}, [%[rhs_ptr]], x8\n" |
|
"mov v0.d[1], v0.d[0]\n" |
|
"ld1 {v1.4h}, [%[rhs2_ptr]], x8\n" |
|
"mov v1.d[1], v1.d[0]\n" |
|
"ld1 {v8.4h}, [%[rhs3_ptr]], x8\n" |
|
"mov v8.d[1], v8.d[0]\n" |
|
"ld1 {v9.4h}, [%[rhs4_ptr]], x8\n" |
|
"mov v9.d[1], v9.d[0]\n" |
|
"ld1 {v10.4h}, [%[rhs5_ptr]], x8\n" |
|
"mov v10.d[1], v10.d[0]\n" |
|
|
|
|
|
// Fetch the delta that the next iteration's loads will post-increment by.
"ldrsh x8, [%[col_deltas_bytes]], #2\n" |
|
|
|
|
|
// Load one block of 16 x 16-bit weights into v2/v3; advance 32 bytes.
"ld1 {v2.8h, v3.8h}, [%[weights_ptr]], #32\n" |
|
|
|
|
|
// Widening multiply-accumulate of the block against each rhs vector.
"smlal v28.4s, v2.4h, v0.4h\n" |
|
"smlal2 v29.4s, v2.8h, v0.8h\n" |
|
"smlal v30.4s, v3.4h, v0.4h\n" |
|
"smlal2 v31.4s, v3.8h, v0.8h\n" |
|
"smlal v23.4s, v2.4h, v1.4h\n" |
|
"smlal2 v24.4s, v2.8h, v1.8h\n" |
|
"smlal v25.4s, v3.4h, v1.4h\n" |
|
"smlal2 v26.4s, v3.8h, v1.8h\n" |
|
"smlal v19.4s, v2.4h, v8.4h\n" |
|
"smlal2 v20.4s, v2.8h, v8.8h\n" |
|
"smlal v21.4s, v3.4h, v8.4h\n" |
|
"smlal2 v22.4s, v3.8h, v8.8h\n" |
|
"smlal v15.4s, v2.4h, v9.4h\n" |
|
"smlal2 v16.4s, v2.8h, v9.8h\n" |
|
"smlal v17.4s, v3.4h, v9.4h\n" |
|
"smlal2 v18.4s, v3.8h, v9.8h\n" |
|
"smlal v11.4s, v2.4h, v10.4h\n" |
|
"smlal2 v12.4s, v2.8h, v10.8h\n" |
|
"smlal v13.4s, v3.4h, v10.4h\n" |
|
"smlal2 v14.4s, v3.8h, v10.8h\n" |
|
|
|
|
|
// Next block of this row, if any remain.
"subs w6, w6, #1\n" |
|
"bne " LABEL_COL_LOOP "b\n" |
|
|
|
LABEL_SKIP_COL_LOOP |
|
":\n" |
|
|
|
// Horizontal reduction: after three rounds of pairwise adds, lane j of
// v28/v23/v19/v15/v11 is the lane-sum of the j-th accumulator of its
// group.
"addp v28.4s, v28.4s, v29.4s\n" |
|
"addp v23.4s, v23.4s, v24.4s\n" |
|
"addp v19.4s, v19.4s, v20.4s\n" |
|
"addp v15.4s, v15.4s, v16.4s\n" |
|
"addp v11.4s, v11.4s, v12.4s\n" |
|
|
|
"addp v30.4s, v30.4s, v31.4s\n" |
|
"addp v25.4s, v25.4s, v26.4s\n" |
|
"addp v21.4s, v21.4s, v22.4s\n" |
|
"addp v17.4s, v17.4s, v18.4s\n" |
|
"addp v13.4s, v13.4s, v14.4s\n" |
|
|
|
"addp v28.4s, v28.4s, v30.4s\n" |
|
"addp v23.4s, v23.4s, v25.4s\n" |
|
"addp v19.4s, v19.4s, v21.4s\n" |
|
"addp v15.4s, v15.4s, v17.4s\n" |
|
"addp v11.4s, v11.4s, v13.4s\n" |
|
|
|
// Signed rounding shift right by kShiftAmount.
"srshr v28.4s, v28.4s, %[shift_amount]\n" |
|
"srshr v23.4s, v23.4s, %[shift_amount]\n" |
|
"srshr v19.4s, v19.4s, %[shift_amount]\n" |
|
"srshr v15.4s, v15.4s, %[shift_amount]\n" |
|
"srshr v11.4s, v11.4s, %[shift_amount]\n" |
|
|
|
|
|
// Store 4 x 32-bit results per rhs vector; advance each output by 16
// bytes.
"st1 {v28.4s}, [%[out_ptr]], #16\n" |
|
"st1 {v23.4s}, [%[out2_ptr]], #16\n" |
|
"st1 {v19.4s}, [%[out3_ptr]], #16\n" |
|
"st1 {v15.4s}, [%[out4_ptr]], #16\n" |
|
"st1 {v11.4s}, [%[out5_ptr]], #16\n" |
|
|
|
|
|
// Next row; the counter is decremented and tested only here.
"subs %[assigned_rows], %[assigned_rows], #1\n" |
|
"bne " LABEL_ROW_LOOP "b\n" |
|
|
|
|
|
// Outputs: every pointer and the row counter are modified in place ("+r").
: |
|
[out_ptr] "+r"(out_ptr), [out2_ptr] "+r"(out2_ptr), |
|
[out3_ptr] "+r"(out3_ptr), [out4_ptr] "+r"(out4_ptr), |
|
[out5_ptr] "+r"(out5_ptr), [weights_ptr] "+r"(weights_ptr), |
|
[col_deltas_bytes] "+r"(col_deltas_bytes), [bias_ptr] "+r"(bias_ptr), |
|
[nnz_per_row] "+r"(nnz_per_row), [assigned_rows] "+r"(assigned_rows), |
|
[rhs_ptr] "+r"(rhs_ptr), [rhs2_ptr] "+r"(rhs2_ptr), |
|
[rhs3_ptr] "+r"(rhs3_ptr), [rhs4_ptr] "+r"(rhs4_ptr), |
|
[rhs5_ptr] "+r"(rhs5_ptr) |
|
// Inputs: the shift count is passed as an immediate ("I" constraint).
: |
|
[shift_amount] "I"(kShiftAmount) |
|
// Clobbers: flags, memory, scratch registers x6-x8 and all vector registers.
: |
|
"cc", "memory", "x6", "x7", "x8", "v0", "v1", "v2", "v3", "v4", "v5", |
|
"v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", |
|
"v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", |
|
"v26", "v27", "v28", "v29", "v30", "v31"); |
|
|
|
} |
|
} |
|
|
|
template <typename Type> |
|
typename std::enable_if<IsFixed32Type<Type>::value>::type SumVectors( |
|
int start, int end, const Type* add1, const Type* add2, Type* result) { |
|
constexpr int kSIMDWidth = 4; |
|
for (int i = start; i < end; i += kSIMDWidth) { |
|
int32x4_t add1_int = vld1q_s32(reinterpret_cast<const int32_t*>(add1 + i)); |
|
int32x4_t add2_int = vld1q_s32(reinterpret_cast<const int32_t*>(add2 + i)); |
|
int32x4_t result_int = vqaddq_s32(add1_int, add2_int); |
|
vst1q_s32(reinterpret_cast<int32_t*>(result + i), result_int); |
|
} |
|
} |
|
|
|
template <typename Type> |
|
typename std::enable_if<IsFixed16Type<Type>::value>::type SumVectors( |
|
int start, int end, const Type* add1, const Type* add2, Type* result) { |
|
constexpr int kSIMDWidth = 8; |
|
for (int i = start; i < end; i += kSIMDWidth) { |
|
int16x8_t add1_int = vld1q_s16(reinterpret_cast<const int16_t*>(add1 + i)); |
|
int16x8_t add2_int = vld1q_s16(reinterpret_cast<const int16_t*>(add2 + i)); |
|
int16x8_t result_int = vqaddq_s16(add1_int, add2_int); |
|
vst1q_s16(reinterpret_cast<int16_t*>(result + i), result_int); |
|
} |
|
} |
|
|
|
} |
|
} |
|
|
|
#undef LABEL_COL_LOOP |
|
#undef LABEL_ROW_LOOP |
|
#undef LABEL_SKIP_COL_LOOP |
|
#undef LABEL_TOP_LOOP |
|
|
|
#endif |
|
#endif |
|
|