// Copyright 2022 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <assert.h>
#include <stddef.h>
#include <stdint.h>

#include <xnnpack/math.h>
#include <xnnpack/fft.h>

#include <arm_neon.h>

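// Radix-4 FFT butterfly microkernel over interleaved complex Q15 (int16) data,
// computed 4 complex values per iteration with NEON. As read from the code
// below: each of the `batch` FFTs is split into four rows of `samples` bytes,
// `twiddle` is an interleaved complex table, and `stride` is its byte step.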
void xnn_cs16_bfly4_ukernel__neon_x4(
    size_t batch,
    size_t samples,
    int16_t* data,
    const int16_t* twiddle,
    size_t stride)
{
  assert(batch != 0);
  assert(samples != 0);
  assert(samples % (sizeof(int16_t) * 8) == 0);
  assert(data != NULL);
  assert(stride != 0);
  assert(twiddle != NULL);

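  // vqrdmulh_s16(x, 8191) computes round(x * 8191 / 32768) ~= x / 4.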
  const int16x4_t vdiv4 = vdup_n_s16(8191);

  int16_t* data3 = data;
  do {
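    // Process each FFT as four quarter rows spaced `samples` bytes apart;
    // data3 finishes pointing just past this FFT, i.e. at the next one.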
    int16_t* data0 = data3;
    int16_t* data1 = (int16_t*) ((uintptr_t) data0 + samples);
    int16_t* data2 = (int16_t*) ((uintptr_t) data1 + samples);
    data3 = (int16_t*) ((uintptr_t) data2 + samples);

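    // Three twiddle streams walk the same table at 1x/2x/3x stride.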
    const int16_t* tw1 = twiddle;
    const int16_t* tw2 = twiddle;
    const int16_t* tw3 = twiddle;

    size_t s = samples;
    for (; s >= sizeof(int16_t) * 8; s -= sizeof(int16_t) * 8) {
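      // vld2 de-interleaves: val[0] holds real parts, val[1] imaginary parts.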
      int16x4x2_t vout0 = vld2_s16(data0);
      int16x4x2_t vout1 = vld2_s16(data1);
      int16x4x2_t vout2 = vld2_s16(data2);
      int16x4x2_t vout3 = vld2_s16(data3);

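      // Gather 4 twiddles per stream: vld2_dup broadcasts the first one, then
      // vld2_lane overwrites lanes 1-3 after each stride advance.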
      int16x4x2_t vtw1 = vld2_dup_s16(tw1);
      int16x4x2_t vtw2 = vld2_dup_s16(tw2);
      int16x4x2_t vtw3 = vld2_dup_s16(tw3);
      tw1 = (const int16_t*) ((uintptr_t) tw1 + stride);
      tw2 = (const int16_t*) ((uintptr_t) tw2 + stride * 2);
      tw3 = (const int16_t*) ((uintptr_t) tw3 + stride * 3);
      vtw1 = vld2_lane_s16(tw1, vtw1, 1);
      vtw2 = vld2_lane_s16(tw2, vtw2, 1);
      vtw3 = vld2_lane_s16(tw3, vtw3, 1);
      tw1 = (const int16_t*) ((uintptr_t) tw1 + stride);
      tw2 = (const int16_t*) ((uintptr_t) tw2 + stride * 2);
      tw3 = (const int16_t*) ((uintptr_t) tw3 + stride * 3);
      vtw1 = vld2_lane_s16(tw1, vtw1, 2);
      vtw2 = vld2_lane_s16(tw2, vtw2, 2);
      vtw3 = vld2_lane_s16(tw3, vtw3, 2);
      tw1 = (const int16_t*) ((uintptr_t) tw1 + stride);
      tw2 = (const int16_t*) ((uintptr_t) tw2 + stride * 2);
      tw3 = (const int16_t*) ((uintptr_t) tw3 + stride * 3);
      vtw1 = vld2_lane_s16(tw1, vtw1, 3);
      vtw2 = vld2_lane_s16(tw2, vtw2, 3);
      vtw3 = vld2_lane_s16(tw3, vtw3, 3);
      tw1 = (const int16_t*) ((uintptr_t) tw1 + stride);
      tw2 = (const int16_t*) ((uintptr_t) tw2 + stride * 2);
      tw3 = (const int16_t*) ((uintptr_t) tw3 + stride * 3);

      // Pre-scale every input by ~1/4 so the radix-4 sums below stay in int16
      // range. Note 32767 / 4 = 8191; an exact 1/4 in Q15 would be 8192.
      vout1.val[0] = vqrdmulh_s16(vout1.val[0], vdiv4);
      vout1.val[1] = vqrdmulh_s16(vout1.val[1], vdiv4);
      vout2.val[0] = vqrdmulh_s16(vout2.val[0], vdiv4);
      vout2.val[1] = vqrdmulh_s16(vout2.val[1], vdiv4);
      vout3.val[0] = vqrdmulh_s16(vout3.val[0], vdiv4);
      vout3.val[1] = vqrdmulh_s16(vout3.val[1], vdiv4);
      vout0.val[0] = vqrdmulh_s16(vout0.val[0], vdiv4);
      vout0.val[1] = vqrdmulh_s16(vout0.val[1], vdiv4);

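      // Complex multiply by the twiddles in 32-bit accumulators:
      // re = a.re*w.re - a.im*w.im,  im = a.re*w.im + a.im*w.re.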
      int32x4_t vacc0r = vmull_s16(vout1.val[0], vtw1.val[0]);
      int32x4_t vacc1r = vmull_s16(vout2.val[0], vtw2.val[0]);
      int32x4_t vacc2r = vmull_s16(vout3.val[0], vtw3.val[0]);
      int32x4_t vacc0i = vmull_s16(vout1.val[0], vtw1.val[1]);
      int32x4_t vacc1i = vmull_s16(vout2.val[0], vtw2.val[1]);
      int32x4_t vacc2i = vmull_s16(vout3.val[0], vtw3.val[1]);
      vacc0r = vmlsl_s16(vacc0r, vout1.val[1], vtw1.val[1]);
      vacc1r = vmlsl_s16(vacc1r, vout2.val[1], vtw2.val[1]);
      vacc2r = vmlsl_s16(vacc2r, vout3.val[1], vtw3.val[1]);
      vacc0i = vmlal_s16(vacc0i, vout1.val[1], vtw1.val[0]);
      vacc1i = vmlal_s16(vacc1i, vout2.val[1], vtw2.val[0]);
      vacc2i = vmlal_s16(vacc2i, vout3.val[1], vtw3.val[0]);
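      // Round the Q30 products back down to Q15.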
      int16x4_t vtmp0r = vrshrn_n_s32(vacc0r, 15);
      int16x4_t vtmp1r = vrshrn_n_s32(vacc1r, 15);
      int16x4_t vtmp2r = vrshrn_n_s32(vacc2r, 15);
      int16x4_t vtmp0i = vrshrn_n_s32(vacc0i, 15);
      int16x4_t vtmp1i = vrshrn_n_s32(vacc1i, 15);
      int16x4_t vtmp2i = vrshrn_n_s32(vacc2i, 15);

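      // Radix-4 combine, with x0 = vout0 and tk the twiddled inputs above:
      //   out0 = (x0 + t2) + (t1 + t3)    out2 = (x0 + t2) - (t1 + t3)
      //   out1 = (x0 - t2) - j*(t1 - t3)  out3 = (x0 - t2) + j*(t1 - t3)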
      const int16x4_t vtmp4r = vsub_s16(vtmp0r, vtmp2r);
      const int16x4_t vtmp4i = vsub_s16(vtmp0i, vtmp2i);
      const int16x4_t vtmp3r = vadd_s16(vtmp0r, vtmp2r);
      const int16x4_t vtmp3i = vadd_s16(vtmp0i, vtmp2i);

      const int16x4_t vtmp5r = vsub_s16(vout0.val[0], vtmp1r);
      const int16x4_t vtmp5i = vsub_s16(vout0.val[1], vtmp1i);
      vout0.val[0] = vadd_s16(vout0.val[0], vtmp1r);
      vout0.val[1] = vadd_s16(vout0.val[1], vtmp1i);

      vout2.val[0] = vsub_s16(vout0.val[0], vtmp3r);
      vout2.val[1] = vsub_s16(vout0.val[1], vtmp3i);
      vout0.val[0] = vadd_s16(vout0.val[0], vtmp3r);
      vout0.val[1] = vadd_s16(vout0.val[1], vtmp3i);

      vout1.val[0] = vadd_s16(vtmp5r, vtmp4i);
      vout1.val[1] = vsub_s16(vtmp5i, vtmp4r);
      vout3.val[0] = vsub_s16(vtmp5r, vtmp4i);
      vout3.val[1] = vadd_s16(vtmp5i, vtmp4r);

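      // Re-interleave and store, advancing 4 complex values (8 int16s).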
      vst2_s16(data0, vout0);  data0 += 8;
      vst2_s16(data1, vout1);  data1 += 8;
      vst2_s16(data2, vout2);  data2 += 8;
      vst2_s16(data3, vout3);  data3 += 8;
    }
  } while (--batch != 0);
}