test / src /cs16-bfly4 /cs16-bfly4-neon-x4.c
Androidonnxfork's picture
Upload folder using huggingface_hub
8b7c501
raw
history blame
No virus
5.11 kB
// Copyright 2022 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
#include <assert.h>
#include <stddef.h>
#include <stdint.h>
#include <xnnpack/math.h>
#include <xnnpack/fft.h>
#include <arm_neon.h>
void xnn_cs16_bfly4_ukernel__neon_x4(
size_t batch,
size_t samples,
int16_t* data,
const int16_t* twiddle,
size_t stride)
{
assert(batch != 0);
assert(samples != 0);
assert(samples % (sizeof(int16_t) * 8) == 0);
assert(data != NULL);
assert(stride != 0);
assert(twiddle != NULL);
const int16x4_t vdiv4 = vdup_n_s16(8191);
int16_t* data3 = data;
do {
int16_t* data0 = data3;
int16_t* data1 = (int16_t*) ((uintptr_t) data0 + samples);
int16_t* data2 = (int16_t*) ((uintptr_t) data1 + samples);
data3 = (int16_t*) ((uintptr_t) data2 + samples);
const int16_t* tw1 = twiddle;
const int16_t* tw2 = twiddle;
const int16_t* tw3 = twiddle;
size_t s = samples;
for (; s >= sizeof(int16_t) * 8; s -= sizeof(int16_t) * 8) {
int16x4x2_t vout0 = vld2_s16(data0);
int16x4x2_t vout1 = vld2_s16(data1);
int16x4x2_t vout2 = vld2_s16(data2);
int16x4x2_t vout3 = vld2_s16(data3);
int16x4x2_t vtw1 = vld2_dup_s16(tw1);
int16x4x2_t vtw2 = vld2_dup_s16(tw2);
int16x4x2_t vtw3 = vld2_dup_s16(tw3);
tw1 = (const int16_t*) ((uintptr_t) tw1 + stride);
tw2 = (const int16_t*) ((uintptr_t) tw2 + stride * 2);
tw3 = (const int16_t*) ((uintptr_t) tw3 + stride * 3);
vtw1 = vld2_lane_s16(tw1, vtw1, 1);
vtw2 = vld2_lane_s16(tw2, vtw2, 1);
vtw3 = vld2_lane_s16(tw3, vtw3, 1);
tw1 = (const int16_t*) ((uintptr_t) tw1 + stride);
tw2 = (const int16_t*) ((uintptr_t) tw2 + stride * 2);
tw3 = (const int16_t*) ((uintptr_t) tw3 + stride * 3);
vtw1 = vld2_lane_s16(tw1, vtw1, 2);
vtw2 = vld2_lane_s16(tw2, vtw2, 2);
vtw3 = vld2_lane_s16(tw3, vtw3, 2);
tw1 = (const int16_t*) ((uintptr_t) tw1 + stride);
tw2 = (const int16_t*) ((uintptr_t) tw2 + stride * 2);
tw3 = (const int16_t*) ((uintptr_t) tw3 + stride * 3);
vtw1 = vld2_lane_s16(tw1, vtw1, 3);
vtw2 = vld2_lane_s16(tw2, vtw2, 3);
vtw3 = vld2_lane_s16(tw3, vtw3, 3);
tw1 = (const int16_t*) ((uintptr_t) tw1 + stride);
tw2 = (const int16_t*) ((uintptr_t) tw2 + stride * 2);
tw3 = (const int16_t*) ((uintptr_t) tw3 + stride * 3);
// Note 32767 / 4 = 8191. Should be 8192.
vout1.val[0] = vqrdmulh_s16(vout1.val[0], vdiv4);
vout1.val[1] = vqrdmulh_s16(vout1.val[1], vdiv4);
vout2.val[0] = vqrdmulh_s16(vout2.val[0], vdiv4);
vout2.val[1] = vqrdmulh_s16(vout2.val[1], vdiv4);
vout3.val[0] = vqrdmulh_s16(vout3.val[0], vdiv4);
vout3.val[1] = vqrdmulh_s16(vout3.val[1], vdiv4);
vout0.val[0] = vqrdmulh_s16(vout0.val[0], vdiv4);
vout0.val[1] = vqrdmulh_s16(vout0.val[1], vdiv4);
int32x4_t vacc0r = vmull_s16(vout1.val[0], vtw1.val[0]);
int32x4_t vacc1r = vmull_s16(vout2.val[0], vtw2.val[0]);
int32x4_t vacc2r = vmull_s16(vout3.val[0], vtw3.val[0]);
int32x4_t vacc0i = vmull_s16(vout1.val[0], vtw1.val[1]);
int32x4_t vacc1i = vmull_s16(vout2.val[0], vtw2.val[1]);
int32x4_t vacc2i = vmull_s16(vout3.val[0], vtw3.val[1]);
vacc0r = vmlsl_s16(vacc0r, vout1.val[1], vtw1.val[1]);
vacc1r = vmlsl_s16(vacc1r, vout2.val[1], vtw2.val[1]);
vacc2r = vmlsl_s16(vacc2r, vout3.val[1], vtw3.val[1]);
vacc0i = vmlal_s16(vacc0i, vout1.val[1], vtw1.val[0]);
vacc1i = vmlal_s16(vacc1i, vout2.val[1], vtw2.val[0]);
vacc2i = vmlal_s16(vacc2i, vout3.val[1], vtw3.val[0]);
int16x4_t vtmp0r = vrshrn_n_s32(vacc0r, 15);
int16x4_t vtmp1r = vrshrn_n_s32(vacc1r, 15);
int16x4_t vtmp2r = vrshrn_n_s32(vacc2r, 15);
int16x4_t vtmp0i = vrshrn_n_s32(vacc0i, 15);
int16x4_t vtmp1i = vrshrn_n_s32(vacc1i, 15);
int16x4_t vtmp2i = vrshrn_n_s32(vacc2i, 15);
const int16x4_t vtmp4r = vsub_s16(vtmp0r, vtmp2r);
const int16x4_t vtmp4i = vsub_s16(vtmp0i, vtmp2i);
const int16x4_t vtmp3r = vadd_s16(vtmp0r, vtmp2r);
const int16x4_t vtmp3i = vadd_s16(vtmp0i, vtmp2i);
const int16x4_t vtmp5r = vsub_s16(vout0.val[0], vtmp1r);
const int16x4_t vtmp5i = vsub_s16(vout0.val[1], vtmp1i);
vout0.val[0] = vadd_s16(vout0.val[0], vtmp1r);
vout0.val[1] = vadd_s16(vout0.val[1], vtmp1i);
vout2.val[0] = vsub_s16(vout0.val[0], vtmp3r);
vout2.val[1] = vsub_s16(vout0.val[1], vtmp3i);
vout0.val[0] = vadd_s16(vout0.val[0], vtmp3r);
vout0.val[1] = vadd_s16(vout0.val[1], vtmp3i);
vout1.val[0] = vadd_s16(vtmp5r, vtmp4i);
vout1.val[1] = vsub_s16(vtmp5i, vtmp4r);
vout3.val[0] = vsub_s16(vtmp5r, vtmp4i);
vout3.val[1] = vadd_s16(vtmp5i, vtmp4r);
vst2_s16(data0, vout0); data0 += 8;
vst2_s16(data1, vout1); data1 += 8;
vst2_s16(data2, vout2); data2 += 8;
vst2_s16(data3, vout3); data3 += 8;
}
} while (--batch != 0);
}