test / src /f32-gemm /neon-ld64.c.in
Androidonnxfork's picture
Upload folder using huggingface_hub
8b7c501
raw
history blame
No virus
8.94 kB
// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
// Template preconditions: NR must be a multiple of 4 because accumulators are
// handled as 4-lane float vectors, and DATATYPE selects the weight format.
// NOTE(review): the main-loop QC4/QC8 weight loads iterate over columns in
// groups of 8, so those variants look like they need NR % 8 == 0 - confirm.
$assert NR % 4 == 0
$assert DATATYPE in ["F32", "QC4", "QC8"]
// ABC maps a column index to the character used in generated variable names
// (for example columns 0..3 become the suffix "0123").
$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"
// Multiply-accumulate intrinsic names: the vfma* forms when FMA is set,
// otherwise the vmla* forms.
$VMULADDQ_F32 = "vfmaq_f32" if FMA else "vmlaq_f32"
$VMULADDQ_LANE_F32 = "vfmaq_lane_f32" if FMA else "vmlaq_lane_f32"
#include <assert.h>
#include <arm_neon.h>
#include <xnnpack/gemm.h>
// GEMM micro-kernel template with min/max output clamping, written with NEON
// intrinsics. Each pass of the outer loop produces one tile of up to MR rows by
// NR columns; the main K loop consumes two K elements per step using one
// 64-bit load from every A row.
//
// Template knobs as used below:
//   DUP      - broadcast each A lane with vdupq_lane_f32 and use the plain
//              multiply-accumulate, instead of the *_lane_f32 form.
//   INC      - initial accumulators are read from `acc` rather than from `w`.
//   DATATYPE - F32 reads float weights; QC8 reads int8 weights; QC4 reads
//              bytes holding two 4-bit weights (low nibble = first K element,
//              high nibble = second). Both QC paths multiply the accumulators
//              by per-column float scales read from `w` after the K loop.
//
// Parameters:
//   mr        - number of rows to compute; asserted to be in 1..MR.
//   nc        - number of output columns; processed NR at a time with a tail.
//   kc        - reduction size in bytes; asserted nonzero and a multiple of
//               sizeof(float).
//   a         - first row of A; following rows are a_stride bytes apart.
//   w         - packed weights, read sequentially: per tile the initial
//               accumulator floats (when INC is not set), then the weights for
//               each K element, then the scales (QC4/QC8 only).
//   c         - first output row; rows are cm_stride bytes apart and every row
//               pointer advances by cn_stride bytes after a full tile.
//   acc       - (INC only) initial accumulator values, read sequentially.
//   params    - supplies scalar.min and scalar.max for the clamp; QC4 also
//               reads scalar.minus_kernel_zero_point.
$ISA = ("neonfma" if DUP else "aarch64_neonfma") if FMA else "neon"
$DATATYPE_SPEC = {"F32": "f32", "QC8": "f32_qc8w", "QC4": "f32_qc4w"}[DATATYPE]
void xnn_${DATATYPE_SPEC}_gemm${"inc" if INC else ""}_minmax_ukernel_${MR}x${NR}__${ISA}_${"dup" if DUP else "lane"}_ld64(
size_t mr,
size_t nc,
size_t kc,
const float* restrict a,
size_t a_stride,
$if DATATYPE == "F32":
const float* restrict w,
$else:
const void* restrict w,
float* restrict c,
size_t cm_stride,
size_t cn_stride,
$if INC:
const float* restrict acc,
$if DATATYPE == "QC4":
const union xnn_f32_qc4w_minmax_params params[restrict XNN_MIN_ELEMENTS(1)])
$else:
const union xnn_f32_minmax_params params[restrict XNN_MIN_ELEMENTS(1)])
{
// Argument contract checks (debug builds only, via assert).
assert(mr != 0);
assert(mr <= ${MR});
assert(nc != 0);
assert(kc != 0);
assert(kc % sizeof(float) == 0);
assert(a != NULL);
assert(w != NULL);
assert(c != NULL);
$if INC:
assert(acc != NULL);
// Per-row input and output pointers. Given mr <= MR, each of the three
// generated conditions below is true exactly when mr <= M; in that case row M
// reuses the pointers of row M-1, so it reads and writes the same memory as
// that row.
const float* a0 = a;
float* c0 = c;
$for M in range(1, MR):
const float* a${M} = (const float*) ((uintptr_t) a${M-1} + a_stride);
float* c${M} = (float*) ((uintptr_t) c${M-1} + cm_stride);
$if M % 2 == 0:
if XNN_UNPREDICTABLE(mr <= ${M}) {
a${M} = a${M-1};
c${M} = c${M-1};
}
$elif M + 1 == MR:
if XNN_UNPREDICTABLE(mr != ${M+1}) {
a${M} = a${M-1};
c${M} = c${M-1};
}
$else:
if XNN_UNPREDICTABLE(mr < ${M+1}) {
a${M} = a${M-1};
c${M} = c${M-1};
}
// QC4 only: minus_kernel_zero_point is added to every unpacked weight, and
// vmask keeps the low 4 bits of each weight byte.
$if DATATYPE == "QC4":
const int16x8_t vminus_kernel_zero_point = vld1q_dup_s16(&params->scalar.minus_kernel_zero_point[0]);
const uint8x8_t vmask = vmov_n_u8(UINT8_C(0xF));
do {
// Seed the accumulators: from acc when INC is set; otherwise row 0 is loaded
// from w and copied to the remaining rows.
$if INC:
$for M in range(MR):
$for N in range(0, NR, 4):
float32x4_t vacc${M}x${ABC[N:N+4]} = vld1q_f32(acc); acc += 4;
$else:
$for N in range(0, NR, 4):
$if DATATYPE == "F32":
float32x4_t vacc0x${ABC[N:N+4]} = vld1q_f32(w); w += 4;
$else:
float32x4_t vacc0x${ABC[N:N+4]} = vld1q_f32(w); w = (const float*) w + 4;
$for M in range(1, MR):
$for N in range(0, NR, 4):
float32x4_t vacc${M}x${ABC[N:N+4]} = vacc0x${ABC[N:N+4]};
// Main loop: k counts remaining bytes of the reduction dimension; every
// iteration loads two floats per A row and the matching weights for both K
// elements (converted to float for QC4/QC8), then accumulates.
size_t k = kc;
for (; k >= 2 * sizeof(float); k -= 2 * sizeof(float)) {
$for M in range(MR):
const float32x2_t va${M} = vld1_f32(a${M}); a${M} += 2;
$if DATATYPE == "F32":
$for L in range(2):
$for N in range(0, NR, 4):
const float32x4_t vb${ABC[N:N+4]}c${L} = vld1q_f32(w); w += 4;
$elif DATATYPE == "QC4":
$for N in range(0, NR, 8):
const uint8x8_t vw${ABC[N:N+8]}c01 = vld1_u8(w); w = (const uint8_t*) w + 8;
const uint8x8_t vw${ABC[N:N+8]}c0 = vand_u8(vw${ABC[N:N+8]}c01, vmask);
const uint8x8_t vw${ABC[N:N+8]}c1 = vshr_n_u8(vw${ABC[N:N+8]}c01, 4);
$for L in range(2):
const int16x8_t vxw${ABC[N:N+8]}c${L} = vaddw_s8(vminus_kernel_zero_point, vreinterpret_s8_u8(vw${ABC[N:N+8]}c${L}));
$for L in range(2):
const int32x4_t vxw${ABC[N:N+4]}c${L} = vmovl_s16(vget_low_s16(vxw${ABC[N:N+8]}c${L}));
const int32x4_t vxw${ABC[N+4:N+8]}c${L} = vmovl_s16(vget_high_s16(vxw${ABC[N:N+8]}c${L}));
$for N in range(0, NR, 4):
$for L in range(2):
const float32x4_t vb${ABC[N:N+4]}c${L} = vcvtq_f32_s32(vxw${ABC[N:N+4]}c${L});
$elif DATATYPE == "QC8":
$for N in range(0, NR, 8):
$for L in range(2):
const int8x8_t vw${ABC[N:N+8]}c${L} = vld1_s8(w); w = (const int8_t*) w + 8;
$for L in range(2):
const int16x8_t vxw${ABC[N:N+8]}c${L} = vmovl_s8(vw${ABC[N:N+8]}c${L});
$for L in range(2):
const int32x4_t vxw${ABC[N:N+4]}c${L} = vmovl_s16(vget_low_s16(vxw${ABC[N:N+8]}c${L}));
const int32x4_t vxw${ABC[N+4:N+8]}c${L} = vmovl_s16(vget_high_s16(vxw${ABC[N:N+8]}c${L}));
$for N in range(0, NR, 4):
$for L in range(2):
const float32x4_t vb${ABC[N:N+4]}c${L} = vcvtq_f32_s32(vxw${ABC[N:N+4]}c${L});
$for L in range(2):
$if DUP:
$for M in range(MR):
const float32x4_t va${M}c${L} = vdupq_lane_f32(va${M}, ${L});
$for N in range(0, NR, 4):
$for M in range(MR):
vacc${M}x${ABC[N:N+4]} = ${VMULADDQ_F32}(vacc${M}x${ABC[N:N+4]}, va${M}c${L}, vb${ABC[N:N+4]}c${L});
$else:
$for N in range(0, NR, 4):
$for M in range(MR):
vacc${M}x${ABC[N:N+4]} = ${VMULADDQ_LANE_F32}(vacc${M}x${ABC[N:N+4]}, vb${ABC[N:N+4]}c${L}, va${M}, ${L});
}
// Remainder: because kc is a multiple of sizeof(float), at most one K element
// is left here. Each A value is broadcast to all four lanes.
// NOTE(review): the QC4 branch widens the loaded bytes without applying vmask,
// so it presumably relies on the packed high nibble being benign for the
// final odd K element - confirm against the weight packing code.
if XNN_UNLIKELY(k != 0) {
$for M in range(MR):
const float32x4_t va${M} = vld1q_dup_f32(a${M}); a${M} += 1;
$if DATATYPE == "F32":
$for N in range(0, NR, 4):
const float32x4_t vb${ABC[N:N+4]} = vld1q_f32(w); w += 4;
$elif DATATYPE == "QC4":
$for N in range(0, NR, 8):
const uint8x8_t vw${ABC[N:N+8]} = vld1_u8(w); w = (const uint8_t*) w + 8;
const int16x8_t vxw${ABC[N:N+8]} = vaddw_s8(vminus_kernel_zero_point, vreinterpret_s8_u8(vw${ABC[N:N+8]}));
const int32x4_t vxw${ABC[N:N+4]} = vmovl_s16(vget_low_s16(vxw${ABC[N:N+8]}));
const int32x4_t vxw${ABC[N+4:N+8]} = vmovl_s16(vget_high_s16(vxw${ABC[N:N+8]}));
$for N in range(0, NR, 4):
const float32x4_t vb${ABC[N:N+4]} = vcvtq_f32_s32(vxw${ABC[N:N+4]});
$elif DATATYPE == "QC8":
$for N in range(0, NR, 4):
const int8x8_t vw${ABC[N:N+4]}${ABC[N:N+4]} = vreinterpret_s8_u32(vld1_dup_u32(w)); w = (const int8_t*) w + 4;
$for N in range(0, NR, 4):
const int16x8_t vxw${ABC[N:N+4]}${ABC[N:N+4]} = vmovl_s8(vw${ABC[N:N+4]}${ABC[N:N+4]});
$for N in range(0, NR, 4):
const int32x4_t vxw${ABC[N:N+4]} = vmovl_s16(vget_low_s16(vxw${ABC[N:N+4]}${ABC[N:N+4]}));
$for N in range(0, NR, 4):
const float32x4_t vb${ABC[N:N+4]} = vcvtq_f32_s32(vxw${ABC[N:N+4]});
$for N in range(0, NR, 4):
$for M in range(MR):
vacc${M}x${ABC[N:N+4]} = ${VMULADDQ_F32}(vacc${M}x${ABC[N:N+4]}, va${M}, vb${ABC[N:N+4]});
}
// QC4/QC8 only: multiply every accumulator by the float scale of its column,
// read from w right after the weights of this tile.
$if DATATYPE in ["QC8", "QC4"]:
$for N in range(0, NR, 4):
const float32x4_t vscale${ABC[N:N+4]} = vld1q_f32(w); w = (const float*) w + 4;
$for M in range(MR):
vacc${M}x${ABC[N:N+4]} = vmulq_f32(vacc${M}x${ABC[N:N+4]}, vscale${ABC[N:N+4]});
// Clamp the results: first to the upper bound, then to the lower bound.
const float32x4_t vmax = vld1q_dup_f32(&params->scalar.max);
$for N in range(0, NR, 4):
$for M in range(MR):
vacc${M}x${ABC[N:N+4]} = vminq_f32(vacc${M}x${ABC[N:N+4]}, vmax);
const float32x4_t vmin = vld1q_dup_f32(&params->scalar.min);
$for N in range(0, NR, 4):
$for M in range(MR):
vacc${M}x${ABC[N:N+4]} = vmaxq_f32(vacc${M}x${ABC[N:N+4]}, vmin);
// Full tile: store NR columns for every row (highest row first), advance each
// output pointer by cn_stride bytes, and rewind each A pointer by kc bytes so
// the next tile re-reads the same A rows.
if XNN_LIKELY(nc >= ${NR}) {
$for M in reversed(range(MR)):
vst1q_f32(c${M}, vacc${M}x${ABC[0:4]});
$for N in range(4, NR, 4):
vst1q_f32(c${M} + ${N}, vacc${M}x${ABC[N:N+4]});
c${M} = (float*) ((uintptr_t) c${M} + cn_stride);
$for M in reversed(range(MR)):
a${M} = (const float*) ((uintptr_t) a${M} - kc);
nc -= ${NR};
} else {
// Tail: fewer than NR columns remain. The set bits of nc are handled from the
// largest power of two below NR down to 1: store that many columns, then move
// the not-yet-stored lanes down to the front of the accumulators. This is the
// last tile, so nc is set to 0 afterwards.
$for LOG2N in reversed(range(NR.bit_length())):
$if NR != 1 << LOG2N:
if (nc & ${1 << LOG2N}) {
$if LOG2N >= 2:
$for N in range(0, 1 << LOG2N, 4):
$for M in reversed(range(MR)):
vst1q_f32(c${M}, vacc${M}x${ABC[N:N+4]}); c${M} += 4;
$for M in reversed(range(MR)):
$for N in range(0, NR - (1 << LOG2N), 4):
vacc${M}x${ABC[N:N+4]} = vacc${M}x${ABC[N + (1 << LOG2N):N + (1 << LOG2N)+4]};
$elif LOG2N == 1:
$for M in reversed(range(MR)):
vst1_f32(c${M}, vacc${M}x${ABC[0:2]}); c${M} += 2;
$for M in reversed(range(MR)):
vacc${M}x${ABC[0:2]} = vget_high_f32(vacc${M}x${ABC[0:4]});
$elif LOG2N == 0:
$for M in reversed(range(MR)):
vst1_lane_f32(c${M}, vacc${M}x${ABC[0:2]}, 0);
}
$if LOG2N == 2:
$for M in reversed(range(MR)):
float32x2_t vacc${M}x${ABC[0:2]} = vget_low_f32(vacc${M}x${ABC[0:4]});
nc = 0;
}
} while (nc != 0);
}