#ifndef LYRA_CODEC_SPARSE_MATMUL_COMPUTE_GRU_GATES_AVX_FIXED_H_
#define LYRA_CODEC_SPARSE_MATMUL_COMPUTE_GRU_GATES_AVX_FIXED_H_

#include <cstdint>
#if defined __AVX2__
#include <immintrin.h>
#endif  // __AVX2__
#include <utility>
#include <vector>

#include "sparse_matmul/compute/ar_inputs.h"
#include "sparse_matmul/numerics/fast_transcendentals.h"

namespace csrblocksparse {

#if defined __AVX2__

constexpr int kAVX2SIMDWidth = 8;
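
// Loads 8x fixed32 values from |ptr0| and adds them to |input|. If
// |kTwoInputs| is true, also loads and adds 8x fixed32 values from |ptr1|.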
template <bool kTwoInputs>
inline __m256i LoadAndAddFixed32(const int32_t* ptr0, const int32_t* ptr1,
                                 const __m256i& input) {
  __m256i data0 = _mm256_load_si256(reinterpret_cast<const __m256i*>(ptr0));
  if (kTwoInputs) {
    __m256i data1 = _mm256_load_si256(reinterpret_cast<const __m256i*>(ptr1));
    data0 = _mm256_add_epi32(data0, data1);
  }
  return _mm256_add_epi32(data0, input);
}
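
// Loads 8x fixed32 values from |ptr0| (plus |ptr1| if |kTwoInputs| is true),
// converts the sum to float, multiplies by |float_factor| and adds |input|.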
template <bool kTwoInputs>
inline __m256 LoadMultiplyAddToFloat(const int32_t* ptr0, const int32_t* ptr1,
                                     const __m256& float_factor,
                                     const __m256& input) {
  __m256i data0 = _mm256_load_si256(reinterpret_cast<const __m256i*>(ptr0));
  if (kTwoInputs) {
    __m256i data1 = _mm256_load_si256(reinterpret_cast<const __m256i*>(ptr1));
    data0 = _mm256_add_epi32(data0, data1);
  }
  __m256 float_result = _mm256_cvtepi32_ps(data0);
  float_result = _mm256_mul_ps(float_result, float_factor);
  return _mm256_add_ps(float_result, input);
}
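
// Multiplies 8 pairs of floats in |input_pairs| by 16 weights at |ptr0_1| and
// horizontally adds each pair, producing 8 values. If |kThreeInputs| is true,
// also multiply-adds 8 weights at |ptr2| by |third_input|. The result is
// added to |accumulator|.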
template <bool kThreeInputs>
inline __m256 MultiplyAddFloat(const __m256& input_pairs,
                               const __m256& third_input, const float* ptr0_1,
                               const float* ptr2, const __m256& accumulator) {
  __m256 data_pair0 = _mm256_load_ps(ptr0_1);
  __m256 data_pair1 = _mm256_load_ps(ptr0_1 + 8);
  data_pair0 = _mm256_mul_ps(data_pair0, input_pairs);
  data_pair1 = _mm256_mul_ps(data_pair1, input_pairs);
  data_pair0 = _mm256_hadd_ps(data_pair0, data_pair1);
  // Swap the middle two 64-bit elements to undo the per-lane interleaving of
  // the hadd results. (Cast via __m256d, as permute4x64 operates on doubles.)
  data_pair0 = _mm256_castpd_ps(
      _mm256_permute4x64_pd(_mm256_castps_pd(data_pair0), 0xd8));
  if (kThreeInputs) {
    // Multiply-add the third AR input.
    data_pair1 = _mm256_load_ps(ptr2);
    data_pair1 = _mm256_mul_ps(data_pair1, third_input);
    data_pair0 = _mm256_add_ps(data_pair0, data_pair1);
  }
  // Add in the accumulator.
  return _mm256_add_ps(data_pair0, accumulator);
}
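
// Combines the cell and reset gates into hbar = tanh(cell + reset * gates),
// then interpolates with the previous state, h = hbar + update * (h - hbar),
// for 16 fixed16 states at once. Returns the new state packed back to
// fixed16.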
template <int kInputMantissaBits, int kStateMantissaBits, bool kSplitGates>
inline __m256i GRUComputeState(const __m256& cell0, const __m256& cell1,
                               const __m256& reset0, const __m256& reset1,
                               const __m256& update0, const __m256& update1,
                               const int32_t* gate_ptr,
                               const int32_t* gate_other_ptr,
                               const void* gru_h_ptr) {
  // Compute the tanh input: cell + reset * recurrent gates.
  __m256 float_gru0 = LoadMultiplyAddToFloat<kSplitGates>(
      gate_ptr, gate_other_ptr, reset0, cell0);
  __m256 float_gru1 = LoadMultiplyAddToFloat<kSplitGates>(
      gate_ptr + kAVX2SIMDWidth, gate_other_ptr + kAVX2SIMDWidth, reset1,
      cell1);
  // hbar = tanh of the combined input.
  __m256 hbar0, hbar1;
  float_tanh_float<kInputMantissaBits, TM_ORDER4_FLOAT>(float_gru0, float_gru1,
                                                        hbar0, hbar1);
  // Load the 16x fixed16 previous state and convert to float.
  __m256i gru = _mm256_load_si256(reinterpret_cast<__m256i const*>(gru_h_ptr));
  __m256 state_factor =
      _mm256_set1_ps(1.0f / (static_cast<float>(1 << kStateMantissaBits)));
  float_gru0 =
      _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_castsi256_si128(gru)));
  float_gru1 = _mm256_cvtepi32_ps(
      _mm256_cvtepi16_epi32(_mm256_extractf128_si256(gru, 1)));
  float_gru0 = _mm256_mul_ps(float_gru0, state_factor);
  float_gru1 = _mm256_mul_ps(float_gru1, state_factor);
  // New state = hbar + update * (state - hbar), rescaled back to fixed16.
  float_gru0 = _mm256_sub_ps(float_gru0, hbar0);
  float_gru1 = _mm256_sub_ps(float_gru1, hbar1);
  float_gru0 = _mm256_mul_ps(float_gru0, update0);
  float_gru1 = _mm256_mul_ps(float_gru1, update1);
  state_factor = _mm256_set1_ps(static_cast<float>(1 << kStateMantissaBits));
  float_gru0 = _mm256_add_ps(float_gru0, hbar0);
  float_gru1 = _mm256_add_ps(float_gru1, hbar1);
  float_gru0 = _mm256_mul_ps(float_gru0, state_factor);
  float_gru1 = _mm256_mul_ps(float_gru1, state_factor);
  return PackFloatsToFixed16(float_gru0, float_gru1);
}
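
// Loads 8x fixed32 inputs, adds the recurrent gate values (from one or two
// pointers, according to |kTwoGates|), converts to float, and multiply-adds
// the AR inputs with their weights unless |kInputsMode| is k0ARInputs.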
template <bool kTwoGates, ARInputsMode kInputsMode>
inline __m256 GruInput32ToFloat(const __m256& paired_ar,
                                const __m256& third_ar,
                                const float* pair_weights,
                                const float* third_weights,
                                const int32_t* gates0, const int32_t* gates1,
                                const int32_t* input) {
  __m256i data32 = _mm256_load_si256(reinterpret_cast<__m256i const*>(input));
  data32 = LoadAndAddFixed32<kTwoGates>(gates0, gates1, data32);
  __m256 float_data = _mm256_cvtepi32_ps(data32);
  if (kInputsMode != ARInputsMode::k0ARInputs) {
    float_data = MultiplyAddFloat<kInputsMode == ARInputsMode::k3ARInputs>(
        paired_ar, third_ar, pair_weights, third_weights, float_data);
  }
  return float_data;
}
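
// Generic GRU gate function covering all the variants used by the AVX2
// fixed-point path. Processes the reset, update and cell gates for outputs
// [start, end), 16 at a time, and writes each new state vector to the number
// of copies given by |kReplicas| (or by the run-time |replicas| argument when
// kReplicas == 0), spaced |replica_stride| apart.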
template <int kInputBits, int kStateBits,
          ARInputsMode kInputsMode = ARInputsMode::k0ARInputs,
          int kReplicas = 1, bool kSplitGates = false>
inline void GruGatesTemplate(
    int start, int end, int state_size, int replicas, int replica_stride,
    const int32_t* gru_recurrent_ptr, const int32_t* input_ptr,
    const std::pair<float, float>* ar_sample01, const float* ar_01_weights,
    const float* ar_sample2, const float* ar_2_weights,
    const int32_t* gru_recurrent_other_ptr, int16_t* gru_state_ptr) {
  constexpr int kQRIncrement = kAVX2SIMDWidth;
  // Increment all the pointers up front to save on pointer arithmetic in the
  // loop.
  input_ptr += start;
  gru_state_ptr += start;
  gru_recurrent_ptr += start;
  if (kSplitGates) gru_recurrent_other_ptr += start;
  __m256 ar_2_inputs, ar_3rd_input;
  if (kInputsMode != ARInputsMode::k0ARInputs) {
    ar_01_weights += 2 * start;
    // Broadcast the pair of AR samples across all 4 pairs of lanes.
    ar_2_inputs = _mm256_castsi256_ps(
        _mm256_set1_epi64x(*reinterpret_cast<const int64_t*>(ar_sample01)));
    if (kInputsMode == ARInputsMode::k3ARInputs) {
      ar_2_weights += start;
      ar_3rd_input = _mm256_set1_ps(*ar_sample2);
    } else {
      ar_3rd_input = {};
    }
  } else {
    ar_2_inputs = {};
    ar_3rd_input = {};
  }
  // The transcendentals operate on two registers at once, so everything is
  // done in duplicate, processing 16 elements per iteration.
  for (int i = start; i < end; i += kQRIncrement * 2) {
    // Sigmoid of the reset gates.
    __m256 reset0 = GruInput32ToFloat<kSplitGates, kInputsMode>(
        ar_2_inputs, ar_3rd_input, ar_01_weights, ar_2_weights,
        gru_recurrent_ptr, gru_recurrent_other_ptr, input_ptr);
    __m256 reset1 = GruInput32ToFloat<kSplitGates, kInputsMode>(
        ar_2_inputs, ar_3rd_input, ar_01_weights + 2 * kQRIncrement,
        ar_2_weights + kQRIncrement, gru_recurrent_ptr + kAVX2SIMDWidth,
        gru_recurrent_other_ptr + kAVX2SIMDWidth, input_ptr + kAVX2SIMDWidth);
    float_sigmoid_float<kInputBits>(reset0, reset1);
    // Sigmoid of the update gates.
    __m256 update0 = GruInput32ToFloat<kSplitGates, kInputsMode>(
        ar_2_inputs, ar_3rd_input, ar_01_weights + 2 * state_size,
        ar_2_weights + state_size, gru_recurrent_ptr + state_size,
        gru_recurrent_other_ptr + state_size, input_ptr + state_size);
    __m256 update1 = GruInput32ToFloat<kSplitGates, kInputsMode>(
        ar_2_inputs, ar_3rd_input,
        ar_01_weights + 2 * state_size + 2 * kQRIncrement,
        ar_2_weights + state_size + kQRIncrement,
        gru_recurrent_ptr + state_size + kAVX2SIMDWidth,
        gru_recurrent_other_ptr + state_size + kAVX2SIMDWidth,
        input_ptr + state_size + kAVX2SIMDWidth);
    float_sigmoid_float<kInputBits>(update0, update1);
    // Load the cell inputs and add the AR contribution, if present.
    __m256 cell0 = _mm256_cvtepi32_ps(_mm256_load_si256(
        reinterpret_cast<__m256i const*>(input_ptr + 2 * state_size)));
    __m256 cell1 =
        _mm256_cvtepi32_ps(_mm256_load_si256(reinterpret_cast<__m256i const*>(
            input_ptr + 2 * state_size + kAVX2SIMDWidth)));
    if (kInputsMode != ARInputsMode::k0ARInputs) {
      cell0 = MultiplyAddFloat<kInputsMode == ARInputsMode::k3ARInputs>(
          ar_2_inputs, ar_3rd_input, ar_01_weights + 4 * state_size,
          ar_2_weights + 2 * state_size, cell0);
      cell1 = MultiplyAddFloat<kInputsMode == ARInputsMode::k3ARInputs>(
          ar_2_inputs, ar_3rd_input,
          ar_01_weights + 4 * state_size + 2 * kQRIncrement,
          ar_2_weights + 2 * state_size + kQRIncrement, cell1);
    }
    __m256i gru_state = GRUComputeState<kInputBits, kStateBits, kSplitGates>(
        cell0, cell1, reset0, reset1, update0, update1,
        gru_recurrent_ptr + 2 * state_size,
        gru_recurrent_other_ptr + 2 * state_size, gru_state_ptr);
    if (kReplicas > 0) {
      // With a compile-time number of replicas, the compiler can fully unroll
      // this store loop.
      for (int j = 0; j < kReplicas; ++j) {
        _mm256_store_si256(
            reinterpret_cast<__m256i*>(gru_state_ptr + j * replica_stride),
            gru_state);
      }
    } else {
      // A run-time number of replicas.
      for (int j = 0; j < replicas; ++j) {
        _mm256_store_si256(
            reinterpret_cast<__m256i*>(gru_state_ptr + j * replica_stride),
            gru_state);
      }
    }
    // Advance all pointers past the 16 elements processed this iteration.
    input_ptr += 2 * kAVX2SIMDWidth;
    gru_state_ptr += 2 * kAVX2SIMDWidth;
    gru_recurrent_ptr += 2 * kAVX2SIMDWidth;
    if (kSplitGates) gru_recurrent_other_ptr += 2 * kAVX2SIMDWidth;
    if (kInputsMode != ARInputsMode::k0ARInputs) {
      ar_01_weights += 4 * kQRIncrement;
      if (kInputsMode == ARInputsMode::k3ARInputs)
        ar_2_weights += 2 * kQRIncrement;
    }
  }
}
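
// Dispatches to a GruGatesTemplate instantiation, hard-coding the common
// replication factors as template arguments so the state-store loop can be
// unrolled at compile time.
//
// A minimal usage sketch, assuming 32-byte-aligned buffers with the reset,
// update and cell gate inputs contiguous at offsets 0, |state_size| and
// 2 * |state_size| (the buffer and constant names here are hypothetical, and
// the unused AR/split-gate pointers are simply left null):
//
//   GruGatesAVXFixed<kInputBits, kStateBits>(
//       /*start=*/0, /*end=*/state_size, state_size, gru_recurrent.data(),
//       gru_input.data(), &ar_sample01, ar_01_weights, /*num_replicas=*/1,
//       /*replica_stride=*/0, /*ar_sample2=*/nullptr,
//       /*ar_2_weights=*/nullptr, /*gru_recurrent_other_ptr=*/nullptr,
//       gru_state.data());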
template <int kInputBits, int kStateBits,
          ARInputsMode kInputsMode = ARInputsMode::k2ARInputs,
          bool kSplitGates = false>
inline void GruGatesAVXFixed(
    int start, int end, int state_size, const int32_t* gru_recurrent_ptr,
    const int32_t* input_ptr, const std::pair<float, float>* ar_sample01,
    const float* ar_01_weights, int num_replicas, int replica_stride,
    const float* ar_sample2, const float* ar_2_weights,
    const int32_t* gru_recurrent_other_ptr, int16_t* gru_state_ptr) {
  // Convert the common replica counts to a template parameter.
  switch (num_replicas) {
    case 1:
      GruGatesTemplate<kInputBits, kStateBits, kInputsMode, 1, kSplitGates>(
          start, end, state_size, num_replicas, replica_stride,
          gru_recurrent_ptr, input_ptr, ar_sample01, ar_01_weights, ar_sample2,
          ar_2_weights, gru_recurrent_other_ptr, gru_state_ptr);
      break;
    case 2:
      GruGatesTemplate<kInputBits, kStateBits, kInputsMode, 2, kSplitGates>(
          start, end, state_size, num_replicas, replica_stride,
          gru_recurrent_ptr, input_ptr, ar_sample01, ar_01_weights, ar_sample2,
          ar_2_weights, gru_recurrent_other_ptr, gru_state_ptr);
      break;
    case 4:
      GruGatesTemplate<kInputBits, kStateBits, kInputsMode, 4, kSplitGates>(
          start, end, state_size, num_replicas, replica_stride,
          gru_recurrent_ptr, input_ptr, ar_sample01, ar_01_weights, ar_sample2,
          ar_2_weights, gru_recurrent_other_ptr, gru_state_ptr);
      break;
    case 6:
      GruGatesTemplate<kInputBits, kStateBits, kInputsMode, 6, kSplitGates>(
          start, end, state_size, num_replicas, replica_stride,
          gru_recurrent_ptr, input_ptr, ar_sample01, ar_01_weights, ar_sample2,
          ar_2_weights, gru_recurrent_other_ptr, gru_state_ptr);
      break;
    default:
      // Zero kReplicas tells the template to use the run-time argument.
      GruGatesTemplate<kInputBits, kStateBits, kInputsMode, 0, kSplitGates>(
          start, end, state_size, num_replicas, replica_stride,
          gru_recurrent_ptr, input_ptr, ar_sample01, ar_01_weights, ar_sample2,
          ar_2_weights, gru_recurrent_other_ptr, gru_state_ptr);
  }
}

#endif  // __AVX2__

}  // namespace csrblocksparse

#endif  // LYRA_CODEC_SPARSE_MATMUL_COMPUTE_GRU_GATES_AVX_FIXED_H_