#ifndef LYRA_CODEC_SPARSE_MATMUL_COMPUTE_GRU_GATES_ARM_H_
#define LYRA_CODEC_SPARSE_MATMUL_COMPUTE_GRU_GATES_ARM_H_

#if defined __ARM_NEON || defined __aarch64__
#include <arm_neon.h>
#endif
#include <cstdint>
#include <utility>

#include "sparse_matmul/compute/ar_inputs.h"
#include "sparse_matmul/numerics/fast_transcendentals.h"

namespace csrblocksparse {

static constexpr int kNeonSIMDWidth = 4;
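
// GRU gate nonlinearities for WaveRNN-style (coarse/fine) autoregressive
// models, implemented with NEON intrinsics. The two functions below (float
// and fixed-point) process |end - start| GRU cells, kNeonSIMDWidth at a time,
// so |end - start| should be a multiple of kNeonSIMDWidth. Inputs per block:
//   - |gru_gates_ptr| (plus |gru_gates_other_ptr| when SplitGates is true):
//     precomputed gate inputs (eg recurrent matmul outputs), laid out as
//     [reset | update | cell], each |proj_size| long.
//   - |conditioning_ptr|: conditioning inputs with the same layout.
//   - optional autoregressive inputs: the previous coarse/fine samples
//     projected through |qr_ptr|, and, with ARInputsMode::k3ARInputs, the
//     current coarse sample projected through |w_hat|.
// Per cell, with the AR contributions folded into the gate sums, the code
// computes:
//   r = sigmoid(reset + cond_reset)
//   u = sigmoid(update + cond_update)
//   hbar = tanh(r * cell + qr_cell + cond_cell)
//   h = u * h_prev + (1 - u) * hbar
// and updates |gru_h_ptr| in place with h.
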
#if defined __ARM_NEON || defined __aarch64__

#if !defined __aarch64__
// vpaddq_f32 is only provided by the AArch64 instruction set; emulate it on
// 32-bit ARM with two 64-bit pairwise adds and a recombine.
inline float32x4_t vpaddq_f32(float32x4_t a, float32x4_t b) {
  float32x2_t a10 = vget_low_f32(a);
  float32x2_t a32 = vget_high_f32(a);
  float32x2_t b10 = vget_low_f32(b);
  float32x2_t b32 = vget_high_f32(b);
  return vcombine_f32(vpadd_f32(a10, a32), vpadd_f32(b10, b32));
}
#endif

template <ARInputsMode kInputsMode, bool SplitGates>
void GoThroughGatesFloat(int start, int end, const float* qr_ptr,
                         const float* gru_gates_ptr,
                         const float* gru_gates_other_ptr,
                         const float* conditioning_ptr, float* gru_h_ptr,
                         const float* w_hat, int proj_size,
                         const float* coarse_at_sminus1,
                         const float* fine_at_sminus1,
                         const float* coarse_at_s) {
  // Move all pointers to the |start| position once, outside the loop.
  conditioning_ptr += start;
  gru_h_ptr += start;
  gru_gates_ptr += start;
  if (SplitGates) {
    DCHECK_NE(gru_gates_other_ptr, nullptr);
    gru_gates_other_ptr += start;
  }
  if (kInputsMode != ARInputsMode::k0ARInputs) {
    DCHECK_NE(qr_ptr, nullptr);
    qr_ptr += 2 * start;
    DCHECK_NE(coarse_at_sminus1, nullptr);
    DCHECK_NE(fine_at_sminus1, nullptr);
    if (kInputsMode == ARInputsMode::k3ARInputs) {
      DCHECK_NE(w_hat, nullptr);
      DCHECK_NE(coarse_at_s, nullptr);
      w_hat += start;
    }
  }
  for (int i = start; i < end; i += kNeonSIMDWidth) {
    float32x4_t reset = vld1q_f32(gru_gates_ptr);
    float32x4_t update = vld1q_f32(gru_gates_ptr + proj_size);
    float32x4_t cell = vld1q_f32(gru_gates_ptr + 2 * proj_size);
    float32x4_t qr_cell;
    if (SplitGates) {
      // The gate inputs were computed in two halves; add the other partials.
      reset = vaddq_f32(reset, vld1q_f32(gru_gates_other_ptr));
      update = vaddq_f32(update, vld1q_f32(gru_gates_other_ptr + proj_size));
      cell = vaddq_f32(cell, vld1q_f32(gru_gates_other_ptr + 2 * proj_size));
    }
    if (kInputsMode != ARInputsMode::k0ARInputs) {
      // Set up the sample vector: previous coarse sample in lanes 0 and 2,
      // previous fine sample in lanes 1 and 3.
      float32x4_t sample = vdupq_n_f32(*coarse_at_sminus1);
      sample = vsetq_lane_f32(*fine_at_sminus1, sample, 1);
      sample = vsetq_lane_f32(*fine_at_sminus1, sample, 3);

      // |qr_ptr| holds interleaved [coarse, fine] weight pairs, so a multiply
      // followed by a pairwise add gives qr[2i]*coarse + qr[2i+1]*fine for
      // each of the 4 cells. All |auto| types below are float32x4_t.
      auto qr_reset_0 = vmulq_f32(vld1q_f32(qr_ptr), sample);
      auto qr_reset_1 = vmulq_f32(vld1q_f32(qr_ptr + 4), sample);
      auto qr_reset = vpaddq_f32(qr_reset_0, qr_reset_1);

      auto qr_update_0 = vmulq_f32(vld1q_f32(qr_ptr + 2 * proj_size), sample);
      auto qr_update_1 =
          vmulq_f32(vld1q_f32(qr_ptr + 4 + 2 * proj_size), sample);
      auto qr_update = vpaddq_f32(qr_update_0, qr_update_1);

      auto qr_cell_0 = vmulq_f32(vld1q_f32(qr_ptr + 4 * proj_size), sample);
      auto qr_cell_1 = vmulq_f32(vld1q_f32(qr_ptr + 4 + 4 * proj_size), sample);
      qr_cell = vpaddq_f32(qr_cell_0, qr_cell_1);

      if (kInputsMode == ARInputsMode::k3ARInputs) {
        // The third AR input (the current coarse sample) is added via |w_hat|.
        float32x4_t w_sample = vdupq_n_f32(*coarse_at_s);
        qr_reset = vmlaq_f32(qr_reset, vld1q_f32(w_hat), w_sample);
        qr_update =
            vmlaq_f32(qr_update, vld1q_f32(w_hat + proj_size), w_sample);
        qr_cell =
            vmlaq_f32(qr_cell, vld1q_f32(w_hat + 2 * proj_size), w_sample);
      }
      reset = vaddq_f32(reset, qr_reset);
      update = vaddq_f32(update, qr_update);
    }
    auto reset_conditioning = vld1q_f32(conditioning_ptr);
    auto update_conditioning = vld1q_f32(conditioning_ptr + proj_size);
    auto cell_conditioning = vld1q_f32(conditioning_ptr + 2 * proj_size);

    reset = fast_sigmoid(vaddq_f32(reset, reset_conditioning));
    update = fast_sigmoid(vaddq_f32(update, update_conditioning));
    if (kInputsMode == ARInputsMode::k0ARInputs) {
      cell = vmulq_f32(reset, cell);
    } else {
      cell = vmlaq_f32(qr_cell, reset, cell);
    }
    auto hbar = fast_tanh(vaddq_f32(cell, cell_conditioning));
    // new_h = hbar + (prev_h - hbar) * update, ie the update gate interpolates
    // between the candidate state and the previous state.
    auto prev_h = vld1q_f32(gru_h_ptr);
    auto diff = vsubq_f32(prev_h, hbar);
    auto new_h = vmlaq_f32(hbar, diff, update);
    vst1q_f32(gru_h_ptr, new_h);

    // Advance all pointers by one SIMD register of cells.
    conditioning_ptr += kNeonSIMDWidth;
    gru_h_ptr += kNeonSIMDWidth;
    gru_gates_ptr += kNeonSIMDWidth;
    if (SplitGates) gru_gates_other_ptr += kNeonSIMDWidth;
    if (kInputsMode != ARInputsMode::k0ARInputs) {
      qr_ptr += 2 * kNeonSIMDWidth;
      if (kInputsMode == ARInputsMode::k3ARInputs) w_hat += kNeonSIMDWidth;
    }
  }
}
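
// Fixed-point variant: the gate inputs and conditioning arrive as int32 fixed
// point with GRUMatMulOutType::kMantissaBits fractional bits, the previous
// coarse/fine samples arrive packed as |ar_at_sminus1| = {coarse, fine}, the
// nonlinearities are evaluated in float, and the new state is written back to
// |gru_h_ptr| as int16 fixed point with GRUStateType::kMantissaBits fractional
// bits.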
template <typename GRUStateType, typename GRUMatMulOutType,
          ARInputsMode kInputsMode, bool SplitGates>
void GoThroughGatesFixed(int start, int end, const float* qr_ptr,
                         const int32_t* gru_gates_ptr,
                         const int32_t* gru_gates_other_ptr,
                         const int32_t* conditioning_ptr, int16_t* gru_h_ptr,
                         const float* w_hat, int proj_size,
                         const std::pair<float, float>* ar_at_sminus1,
                         const float* coarse_at_s) {
  // Move all pointers to the |start| position once, outside the loop.
  conditioning_ptr += start;
  gru_h_ptr += start;
  gru_gates_ptr += start;
  if (SplitGates) {
    DCHECK_NE(gru_gates_other_ptr, nullptr);
    gru_gates_other_ptr += start;
  }
  float32x4_t sample01;
  float32x4_t w_sample;
  if (kInputsMode != ARInputsMode::k0ARInputs) {
    DCHECK_NE(qr_ptr, nullptr);
    qr_ptr += 2 * start;
    DCHECK_NE(ar_at_sminus1, nullptr);
    // Previous coarse sample in lanes 0 and 2, previous fine sample in 1, 3.
    sample01 = vdupq_n_f32(ar_at_sminus1->first);
    sample01 = vsetq_lane_f32(ar_at_sminus1->second, sample01, 1);
    sample01 = vsetq_lane_f32(ar_at_sminus1->second, sample01, 3);
    if (kInputsMode == ARInputsMode::k3ARInputs) {
      DCHECK_NE(w_hat, nullptr);
      DCHECK_NE(coarse_at_s, nullptr);
      w_hat += start;
      w_sample = vdupq_n_f32(*coarse_at_s);
    }
  }
  for (int i = start; i < end; i += kNeonSIMDWidth) {
    auto reset = vld1q_s32(gru_gates_ptr);
    auto update = vld1q_s32(gru_gates_ptr + proj_size);
    auto cell_int = vld1q_s32(gru_gates_ptr + 2 * proj_size);
    if (SplitGates) {
      // The gate inputs were computed in two halves; add the other partials.
      reset = vaddq_s32(reset, vld1q_s32(gru_gates_other_ptr));
      update = vaddq_s32(update, vld1q_s32(gru_gates_other_ptr + proj_size));
      cell_int =
          vaddq_s32(cell_int, vld1q_s32(gru_gates_other_ptr + 2 * proj_size));
    }
    // Convert the cell gate to float; the AR contribution and the tanh are
    // computed in float.
    float32x4_t cell =
        vcvtq_n_f32_s32(cell_int, GRUMatMulOutType::kMantissaBits);
    float32x4_t qr_cell;
    if (kInputsMode != ARInputsMode::k0ARInputs) {
      // |qr_ptr| holds interleaved [coarse, fine] weight pairs; multiply and
      // pairwise-add to get qr[2i]*coarse + qr[2i+1]*fine for each cell.
      float32x4_t qr_reset_0 = vmulq_f32(vld1q_f32(qr_ptr), sample01);
      float32x4_t qr_reset_1 = vmulq_f32(vld1q_f32(qr_ptr + 4), sample01);
      float32x4_t qr_reset = vpaddq_f32(qr_reset_0, qr_reset_1);

      float32x4_t qr_update_0 =
          vmulq_f32(vld1q_f32(qr_ptr + 2 * proj_size), sample01);
      float32x4_t qr_update_1 =
          vmulq_f32(vld1q_f32(qr_ptr + 4 + 2 * proj_size), sample01);
      float32x4_t qr_update = vpaddq_f32(qr_update_0, qr_update_1);

      float32x4_t qr_cell_0 =
          vmulq_f32(vld1q_f32(qr_ptr + 4 * proj_size), sample01);
      float32x4_t qr_cell_1 =
          vmulq_f32(vld1q_f32(qr_ptr + 4 + 4 * proj_size), sample01);
      qr_cell = vpaddq_f32(qr_cell_0, qr_cell_1);
      if (kInputsMode == ARInputsMode::k3ARInputs) {
        // The third AR input (the current coarse sample) is added via |w_hat|,
        // using |w_sample| computed before the loop.
        qr_reset = vmlaq_f32(qr_reset, vld1q_f32(w_hat), w_sample);
        qr_update =
            vmlaq_f32(qr_update, vld1q_f32(w_hat + proj_size), w_sample);
        qr_cell =
            vmlaq_f32(qr_cell, vld1q_f32(w_hat + 2 * proj_size), w_sample);
      }
      // Convert the float AR contributions to fixed point and accumulate.
      reset = vaddq_s32(
          reset, vcvtq_n_s32_f32(qr_reset, GRUMatMulOutType::kMantissaBits));
      update = vaddq_s32(
          update, vcvtq_n_s32_f32(qr_update, GRUMatMulOutType::kMantissaBits));
    }

    auto reset_conditioning = vld1q_s32(conditioning_ptr);
    auto update_conditioning = vld1q_s32(conditioning_ptr + proj_size);
    float32x4_t cell_conditioning =
        vcvtq_n_f32_s32(vld1q_s32(conditioning_ptr + 2 * proj_size),
                        GRUMatMulOutType::kMantissaBits);

    float32x4_t reset_f32 = fast_sigmoid<GRUMatMulOutType::kExponentBits>(
        vaddq_s32(reset, reset_conditioning));
    float32x4_t update_f32 = fast_sigmoid<GRUMatMulOutType::kExponentBits>(
        vaddq_s32(update, update_conditioning));
    if (kInputsMode == ARInputsMode::k0ARInputs) {
      cell = vmulq_f32(reset_f32, cell);
    } else {
      cell = vmlaq_f32(qr_cell, reset_f32, cell);
    }
    float32x4_t hbar = fast_tanh(vaddq_f32(cell, cell_conditioning));
    // new_h = hbar + (prev_h - hbar) * update, with prev_h read from the
    // int16 fixed-point state.
    float32x4_t prev_h = vcvtq_n_f32_s32(vmovl_s16(vld1_s16(gru_h_ptr)),
                                         GRUStateType::kMantissaBits);
    float32x4_t diff = vsubq_f32(prev_h, hbar);
    float32x4_t new_h = vmlaq_f32(hbar, diff, update_f32);

    // Convert back to the int16 state format with rounding: convert with 16
    // extra mantissa bits, then narrow with a rounding shift right by 16.
    vst1_s16(gru_h_ptr,
             vqrshrn_n_s32(
                 vcvtq_n_s32_f32(new_h, GRUStateType::kMantissaBits + 16), 16));

    // Advance all pointers by one SIMD register of cells.
    conditioning_ptr += kNeonSIMDWidth;
    gru_h_ptr += kNeonSIMDWidth;
    gru_gates_ptr += kNeonSIMDWidth;
    if (SplitGates) gru_gates_other_ptr += kNeonSIMDWidth;
    if (kInputsMode != ARInputsMode::k0ARInputs) {
      qr_ptr += 2 * kNeonSIMDWidth;
      if (kInputsMode == ARInputsMode::k3ARInputs) w_hat += kNeonSIMDWidth;
    }
  }
}
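
// Illustrative only (hypothetical caller, not part of this header): with
// buffers gru_gates and conditioning of 3 * proj_size floats each, laid out
// as [reset | update | cell], gru_h of proj_size floats, and proj_size a
// multiple of kNeonSIMDWidth, the float path with no AR inputs and unsplit
// gates could be invoked as:
//
//   GoThroughGatesFloat<ARInputsMode::k0ARInputs, /*SplitGates=*/false>(
//       /*start=*/0, /*end=*/proj_size, /*qr_ptr=*/nullptr, gru_gates.data(),
//       /*gru_gates_other_ptr=*/nullptr, conditioning.data(), gru_h.data(),
//       /*w_hat=*/nullptr, proj_size, /*coarse_at_sminus1=*/nullptr,
//       /*fine_at_sminus1=*/nullptr, /*coarse_at_s=*/nullptr);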

#endif  // defined __ARM_NEON || defined __aarch64__

}  // namespace csrblocksparse

#endif  // LYRA_CODEC_SPARSE_MATMUL_COMPUTE_GRU_GATES_ARM_H_