#ifndef LYRA_CODEC_SPARSE_MATMUL_COMPUTE_KERNELS_AVX_H_
#define LYRA_CODEC_SPARSE_MATMUL_COMPUTE_KERNELS_AVX_H_

#if defined __AVX__
#include <immintrin.h>

#include <algorithm>
#include <type_traits>

#include "sparse_matmul/numerics/fast_transcendentals.h"
#include "sparse_matmul/numerics/fixed_types.h"
#include "sparse_matmul/numerics/float16_types.h"
#include "sparse_matmul/numerics/type_utils.h"
|
namespace csrblocksparse {
namespace detail {

// The AVX kernels below are only specialized for all-float types and, with
// AVX2, for 16-bit fixed-point inputs. Any other type combination falls back
// to the generic kernels via the ShouldEnableGeneric* traits.
template <typename WeightType, typename RhsType, typename OutType>
struct IsAllowableFloatTypes
    : std::integral_constant<bool, std::is_same<WeightType, float>::value &&
                                       std::is_same<RhsType, float>::value &&
                                       std::is_same<OutType, float>::value> {};
|
#if defined __AVX2__

// With AVX2, fixed-point kernels are available for 16-bit weights and rhs,
// producing either 16-bit or 32-bit fixed-point outputs.
template <typename WeightType, typename RhsType, typename OutType>
struct IsAllowableFixedTypes
    : std::integral_constant<bool, (IsFixed16Type<WeightType>::value &&
                                    IsFixed16Type<RhsType>::value) &&
                                       (IsFixed32Type<OutType>::value ||
                                        IsFixed16Type<OutType>::value)> {};
|
template <typename WeightType, typename RhsType, typename OutType>
struct ShouldEnableGenericKernel
    : std::integral_constant<
          bool,
          !IsAllowableFloatTypes<WeightType, RhsType, OutType>::value &&
              !IsAllowableFixedTypes<WeightType, RhsType, OutType>::value> {};

template <typename Type>
struct IsAddableFixedTypes
    : std::integral_constant<bool, IsFixed32Type<Type>::value ||
                                       IsFixed16Type<Type>::value> {};
template <typename Type>
struct ShouldEnableGenericAdd
    : std::integral_constant<bool, !IsAddableFixedTypes<Type>::value> {};
|
#else  // !defined __AVX2__

template <typename WeightType, typename RhsType, typename OutType>
struct ShouldEnableGenericKernel
    : std::integral_constant<
          bool, !IsAllowableFloatTypes<WeightType, RhsType, OutType>::value> {};

template <typename Type>
struct ShouldEnableGenericAdd : std::true_type {};
#endif  // defined __AVX2__
|
template <typename WeightType, typename RhsType, typename OutType>
struct ShouldEnableGenericSpMV_4x4
    : ShouldEnableGenericKernel<WeightType, RhsType, OutType> {};
template <typename WeightType, typename RhsType, typename OutType>
struct ShouldEnableGenericSpMM5_4x4
    : ShouldEnableGenericKernel<WeightType, RhsType, OutType> {};
// There are no AVX specializations of the 1x1 kernels, so the generic
// versions are always used for them.
template <typename WeightType, typename RhsType, typename OutType>
struct ShouldEnableGenericSpMV_1x1 : std::true_type {};
template <typename WeightType, typename RhsType, typename OutType>
struct ShouldEnableGenericSpMM5_1x1 : std::true_type {};
|
// Horizontally reduces the two accumulators of a 4x4 block and stores the 4
// finished float results at |*out_ptr|, advancing the pointer by 4.
// |sum1| holds the partial sums for rows 0-1, |sum2| for rows 2-3.
inline void Extract4Results(bool relu, __m256& sum1, __m256& sum2,
                            float** out_ptr) {
  // Pairwise horizontal adds: after these two steps each 128-bit lane holds
  // the replicated totals [row0 row2 row0 row2] / [row1 row3 row1 row3].
  sum1 = _mm256_hadd_ps(sum1, sum2);
  sum1 = _mm256_hadd_ps(sum1, sum1);
  // Apply the optional ReLU.
  if (relu) {
    sum1 = _mm256_max_ps(sum1, _mm256_setzero_ps());
  }
  // Swap the 128-bit lanes and interleave so the low 4 floats become
  // [row0 row1 row2 row3], then store them.
  sum2 = _mm256_permute2f128_ps(sum1, sum1, 1);
  sum1 = _mm256_unpacklo_ps(sum1, sum2);
  __m128 result = _mm256_extractf128_ps(sum1, 0);
  _mm_store_ps(*out_ptr, result);
  *out_ptr += 4;
}
|
// Float specialization of the 4x4-block sparse matrix * vector kernel.
// |weights_ptr| holds the non-zero 4x4 blocks (16 floats each, row-major).
// |col_deltas_bytes| gives, for each non-zero block, the offset in bytes by
// which to advance |rhs_ptr|. |nnz_per_row| gives the number of non-zero
// blocks in each block-row. |bias_ptr| supplies one bias per output row
// (pre-divided by 4, since the horizontal reduction counts it 4 times).
// |assigned_rows| is the number of block-rows to compute; |rows| and |cols|
// are unused here. If |relu| is non-zero, negative outputs are clipped to 0.
template <typename WeightType, typename RhsType, typename OutType>
typename std::enable_if<std::is_same<WeightType, float>::value &&
                        std::is_same<RhsType, float>::value &&
                        std::is_same<OutType, float>::value>::type
SpMV_4x4(const WeightType* weights_ptr, const int16_t* col_deltas_bytes,
         const int32_t* nnz_per_row, const RhsType* rhs_ptr,
         const typename TypeOfProduct<WeightType, RhsType>::type* bias_ptr,
         OutType* out_ptr, int64_t assigned_rows,
         int64_t rows /* unused */, int64_t cols /* unused */, int relu) {
  for (int reduced_row = 0; reduced_row < assigned_rows; ++reduced_row) {
    // Initialize the accumulators with the biases: rows 0-1 in |sum1|,
    // rows 2-3 in |sum2|, each bias broadcast across a 128-bit lane.
    __m256 sum1 = _mm256_set_m128(_mm_broadcast_ss(bias_ptr + 1),
                                  _mm_broadcast_ss(bias_ptr));
    bias_ptr += 2;
    __m256 sum2 = _mm256_set_m128(_mm_broadcast_ss(bias_ptr + 1),
                                  _mm_broadcast_ss(bias_ptr));
    bias_ptr += 2;

    int reduced_col_count = *nnz_per_row++;
    for (int c = 0; c < reduced_col_count; ++c) {
      int col_delta = *col_deltas_bytes++ / sizeof(RhsType);
      rhs_ptr += col_delta;
      // Multiply this 4x4 block by the corresponding 4 rhs values.
      __m256 rhs =
          _mm256_broadcast_ps(reinterpret_cast<const __m128*>(rhs_ptr));
      __m256 weights1 = _mm256_load_ps(weights_ptr);
      weights_ptr += 8;
      sum1 = _mm256_add_ps(sum1, _mm256_mul_ps(weights1, rhs));
      __m256 weights2 = _mm256_load_ps(weights_ptr);
      weights_ptr += 8;
      sum2 = _mm256_add_ps(sum2, _mm256_mul_ps(weights2, rhs));
    }
    Extract4Results(relu, sum1, sum2, &out_ptr);
  }
}
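
// Illustrative sketch (not part of the library): how a caller might drive
// SpMV_4x4 for a single block-row whose two non-zero 4x4 blocks sit at
// columns 0 and 8. The names and the hand-built layout below are
// hypothetical; the real packing is produced by the sparse-matrix code
// elsewhere in this project.
//
//   alignas(32) float weights[32] = {/* two 4x4 blocks, row-major */};
//   int16_t col_deltas_bytes[2] = {0, 8 * sizeof(float)};
//   int32_t nnz_per_row[1] = {2};
//   alignas(16) float rhs[12] = {/* input vector */};
//   float bias[4] = {/* per-row bias, already divided by 4 */};
//   alignas(16) float out[4];
//   detail::SpMV_4x4<float, float, float>(weights, col_deltas_bytes,
//                                         nnz_per_row, rhs, bias, out,
//                                         /*assigned_rows=*/1, /*rows=*/4,
//                                         /*cols=*/12, /*relu=*/1);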
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// Float specialization of the 4x4-block kernel that multiplies the same
// sparse matrix by 5 right-hand-side vectors at once (the rhs vectors are
// spaced |cols| apart and the outputs |rows| apart). Sharing the weight
// loads across the 5 vectors amortizes the memory traffic.
template <typename WeightType, typename RhsType, typename OutType>
typename std::enable_if<std::is_same<WeightType, float>::value &&
                        std::is_same<RhsType, float>::value &&
                        std::is_same<OutType, float>::value>::type
SpMM5_4x4(const WeightType* weights_ptr, const int16_t* col_deltas_bytes,
          const int32_t* nnz_per_row, const RhsType* rhs_ptr,
          const typename TypeOfProduct<WeightType, RhsType>::type* bias_ptr,
          OutType* out_ptr, int64_t assigned_rows, int64_t rows, int64_t cols,
          int relu) {
  const RhsType* rhs_ptrs[5];
  for (int i = 0; i < 5; ++i) rhs_ptrs[i] = rhs_ptr + i * cols;

  OutType* out_ptrs[5];
  for (int i = 0; i < 5; ++i) out_ptrs[i] = out_ptr + i * rows;

  for (int reduced_row = 0; reduced_row < assigned_rows; ++reduced_row) {
    // Initialize all 5 pairs of accumulators with the same biases.
    __m256 sum1_0 = _mm256_set_m128(_mm_broadcast_ss(bias_ptr + 1),
                                    _mm_broadcast_ss(bias_ptr));
    bias_ptr += 2;
    __m256 sum2_0 = _mm256_set_m128(_mm_broadcast_ss(bias_ptr + 1),
                                    _mm_broadcast_ss(bias_ptr));
    bias_ptr += 2;
    __m256 sum1_1 = sum1_0;
    __m256 sum2_1 = sum2_0;
    __m256 sum1_2 = sum1_0;
    __m256 sum2_2 = sum2_0;
    __m256 sum1_3 = sum1_0;
    __m256 sum2_3 = sum2_0;
    __m256 sum1_4 = sum1_0;
    __m256 sum2_4 = sum2_0;

    int reduced_col_count = *nnz_per_row++;
    for (int c = 0; c < reduced_col_count; ++c) {
      int col_delta = *col_deltas_bytes++ / sizeof(RhsType);
      for (int k = 0; k < 5; ++k) rhs_ptrs[k] += col_delta;

      // Load the 4x4 block once and reuse it for all 5 rhs vectors.
      __m256 rhs =
          _mm256_broadcast_ps(reinterpret_cast<const __m128*>(rhs_ptrs[0]));
      __m256 weights1 = _mm256_load_ps(weights_ptr);
      weights_ptr += 8;
      sum1_0 = _mm256_add_ps(sum1_0, _mm256_mul_ps(weights1, rhs));
      __m256 weights2 = _mm256_load_ps(weights_ptr);
      weights_ptr += 8;
      sum2_0 = _mm256_add_ps(sum2_0, _mm256_mul_ps(weights2, rhs));
      rhs = _mm256_broadcast_ps(reinterpret_cast<const __m128*>(rhs_ptrs[1]));
      sum1_1 = _mm256_add_ps(sum1_1, _mm256_mul_ps(weights1, rhs));
      sum2_1 = _mm256_add_ps(sum2_1, _mm256_mul_ps(weights2, rhs));
      rhs = _mm256_broadcast_ps(reinterpret_cast<const __m128*>(rhs_ptrs[2]));
      sum1_2 = _mm256_add_ps(sum1_2, _mm256_mul_ps(weights1, rhs));
      sum2_2 = _mm256_add_ps(sum2_2, _mm256_mul_ps(weights2, rhs));
      rhs = _mm256_broadcast_ps(reinterpret_cast<const __m128*>(rhs_ptrs[3]));
      sum1_3 = _mm256_add_ps(sum1_3, _mm256_mul_ps(weights1, rhs));
      sum2_3 = _mm256_add_ps(sum2_3, _mm256_mul_ps(weights2, rhs));
      rhs = _mm256_broadcast_ps(reinterpret_cast<const __m128*>(rhs_ptrs[4]));
      sum1_4 = _mm256_add_ps(sum1_4, _mm256_mul_ps(weights1, rhs));
      sum2_4 = _mm256_add_ps(sum2_4, _mm256_mul_ps(weights2, rhs));
    }

    Extract4Results(relu, sum1_0, sum2_0, &out_ptrs[0]);
    Extract4Results(relu, sum1_1, sum2_1, &out_ptrs[1]);
    Extract4Results(relu, sum1_2, sum2_2, &out_ptrs[2]);
    Extract4Results(relu, sum1_3, sum2_3, &out_ptrs[3]);
    Extract4Results(relu, sum1_4, sum2_4, &out_ptrs[4]);
  }
}
|
#ifdef __AVX2__
|
// Fixed-point equivalent of Extract4Results: horizontally reduces the 8
// int32 partial sums in |sum| so that its low 4 elements become the 4 row
// results, then applies rounding, the right-shift to the output format, and
// the optional ReLU.
inline void Compute4Results(bool relu, int kShiftAmount, __m256i& sum) {
  // Pairwise add: each 128-bit lane now holds its two row totals, twice.
  sum = _mm256_hadd_epi32(sum, sum);
  // Reorder the 64-bit elements so the low half is [row0 row1 row2 row3].
  sum = _mm256_permute4x64_epi64(sum, 0xd8);
  if (kShiftAmount > 0) {
    // Round to nearest before shifting down to the output mantissa.
    __m256i rounding = _mm256_set1_epi32(1 << (kShiftAmount - 1));
    sum = _mm256_add_epi32(sum, rounding);
    sum = _mm256_srai_epi32(sum, kShiftAmount);
  }
  // Apply the optional ReLU.
  if (relu) {
    sum = _mm256_max_epi32(sum, _mm256_setzero_si256());
  }
}
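
// For instance (illustrative numbers only): with kShiftAmount == 4, a raw
// accumulator value of 23 becomes (23 + 8) >> 4 == 1, while 25 becomes
// (25 + 8) >> 4 == 2. Adding 1 << (kShiftAmount - 1) before the arithmetic
// shift turns plain truncation into round-to-nearest (ties rounding away
// from zero for positive values).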
|
|
|
|
|
|
|
// Reduces |sum| with Compute4Results and stores the low 4 results as int32.
inline void Extract4xint32(bool relu, int kShiftAmount, __m256i& sum,
                           int32_t** out_ptr) {
  Compute4Results(relu, kShiftAmount, sum);
  __m128i result = _mm256_extractf128_si256(sum, 0);
  _mm_store_si128(reinterpret_cast<__m128i*>(*out_ptr), result);
  *out_ptr += 4;
}
|
// Reduces |sum| with Compute4Results, saturates the results to 16 bits and
// stores them as 4 int16 values.
inline void Extract4xint16(bool relu, int kShiftAmount, __m256i& sum,
                           int16_t** out_ptr) {
  Compute4Results(relu, kShiftAmount, sum);
  // Pack int32 -> int16 with signed saturation; the low 64 bits then hold
  // the 4 results.
  sum = _mm256_packs_epi32(sum, sum);
  *reinterpret_cast<int64_t*>(*out_ptr) = _mm256_extract_epi64(sum, 0);
  *out_ptr += 4;
}
|
// AVX2 fixed-point specialization of the 4x4-block sparse matrix * vector
// kernel: 16-bit fixed-point weights and rhs, 16- or 32-bit fixed-point
// output. The data layout matches the float kernel above; the bias is again
// expected to be pre-divided by 4.
template <typename WeightType, typename RhsType, typename OutType>
typename std::enable_if<
    IsFixed16Type<WeightType>::value && IsFixed16Type<RhsType>::value &&
    (IsFixed32Type<OutType>::value || IsFixed16Type<OutType>::value)>::type
SpMV_4x4(const WeightType* weights_ptr, const int16_t* col_deltas_bytes,
         const int32_t* nnz_per_row, const RhsType* rhs_ptr,
         const typename TypeOfProduct<WeightType, RhsType>::type* bias_ptr,
         OutType* out_ptr, int64_t assigned_rows,
         int64_t rows /* unused */, int64_t cols /* unused */, int relu) {
  // Number of bits by which the 32-bit product must be shifted right to land
  // in the output's fixed-point format.
  constexpr int kShiftAmount =
      TypeOfProduct<WeightType, RhsType>::type::kMantissaBits -
      OutType::kMantissaBits;
  static_assert(kShiftAmount >= 0,
                "Result must have fewer mantissa bits than product");
  for (int reduced_row = 0; reduced_row < assigned_rows; ++reduced_row) {
    // Load 4 biases as [b0 b1 b2 b3 b0 b1 b2 b3].
    __m128i bias = _mm_load_si128(reinterpret_cast<__m128i const*>(bias_ptr));
    __m256i biases = _mm256_set_m128i(bias, bias);
    bias_ptr += 4;
    // Rearrange the 64-bit elements so that after the unpack below each bias
    // appears twice: [b0 b0 b1 b1 b2 b2 b3 b3].
    biases = _mm256_permute4x64_epi64(biases, 0xb4);
    biases = _mm256_unpacklo_epi32(biases, biases);
    // Start the accumulator at twice the bias; together with the later
    // horizontal add this counts the (pre-divided) bias 4 times in total.
    __m256i sum = _mm256_add_epi32(biases, biases);

    int reduced_col_count = *nnz_per_row;
    ++nnz_per_row;
    for (int c = 0; c < reduced_col_count; ++c) {
      int col_delta = *col_deltas_bytes++ / sizeof(RhsType);
      rhs_ptr += col_delta;
      // Load 4 int16 rhs values and broadcast them to every 64-bit element.
      __m128i rhs_64 =
          _mm_loadl_epi64(reinterpret_cast<__m128i const*>(rhs_ptr));
      // Load the 16 int16 weights of the 4x4 block.
      __m256i weights =
          _mm256_load_si256(reinterpret_cast<__m256i const*>(weights_ptr));
      __m256i rhs = _mm256_broadcastq_epi64(rhs_64);
      weights_ptr += 16;
      // madd multiplies int16 pairs and adds adjacent products, yielding 2
      // int32 partial sums per row.
      sum = _mm256_add_epi32(sum, _mm256_madd_epi16(weights, rhs));
    }
    static_assert(
        IsFixed16Type<OutType>::value || IsFixed32Type<OutType>::value,
        "AVX2 kernel only supports fixed16 and fixed32 types");
    if (IsFixed32Type<OutType>::value) {
      Extract4xint32(relu, kShiftAmount, sum,
                     reinterpret_cast<int32_t**>(&out_ptr));
    } else {
      Extract4xint16(relu, kShiftAmount, sum,
                     reinterpret_cast<int16_t**>(&out_ptr));
    }
  }
}
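
// Worked example of the shift (illustrative numbers): if the weight and rhs
// formats carry 11 and 13 mantissa bits respectively, their exact product
// carries 24 mantissa bits, which is what TypeOfProduct is expected to
// report. Writing that into an OutType with 20 mantissa bits gives
// kShiftAmount = 24 - 20 = 4, so each accumulated sum is rounded and shifted
// right by 4 bits in Compute4Results.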
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// AVX2 fixed-point version of the 5-vector kernel: same data layout as the
// fixed-point SpMV_4x4 above, but the weights of each block are reused
// across 5 rhs vectors spaced |cols| apart, producing 5 output vectors
// spaced |rows| apart.
template <typename WeightType, typename RhsType, typename OutType>
typename std::enable_if<
    IsFixed16Type<WeightType>::value && IsFixed16Type<RhsType>::value &&
    (IsFixed32Type<OutType>::value || IsFixed16Type<OutType>::value)>::type
SpMM5_4x4(const WeightType* weights_ptr, const int16_t* col_deltas_bytes,
          const int32_t* nnz_per_row, const RhsType* rhs_ptr,
          const typename TypeOfProduct<WeightType, RhsType>::type* bias_ptr,
          OutType* out_ptr, int64_t assigned_rows, int64_t rows, int64_t cols,
          int relu) {
  constexpr int kShiftAmount =
      TypeOfProduct<WeightType, RhsType>::type::kMantissaBits -
      OutType::kMantissaBits;
  static_assert(kShiftAmount >= 0,
                "Result must have fewer mantissa bits than product");
  const RhsType* rhs_ptrs[5];
  for (int i = 0; i < 5; ++i) rhs_ptrs[i] = rhs_ptr + i * cols;

  OutType* out_ptrs[5];
  for (int i = 0; i < 5; ++i) out_ptrs[i] = out_ptr + i * rows;

  for (int reduced_row = 0; reduced_row < assigned_rows; ++reduced_row) {
    // Set up the biases exactly as in the single-vector kernel
    // ([b0 b0 b1 b1 b2 b2 b3 b3], doubled); all 5 accumulators share them.
    __m128i bias = _mm_load_si128(reinterpret_cast<__m128i const*>(bias_ptr));
    __m256i biases = _mm256_set_m128i(bias, bias);
    bias_ptr += 4;
    biases = _mm256_permute4x64_epi64(biases, 0xb4);
    biases = _mm256_unpacklo_epi32(biases, biases);
    __m256i sum_0 = _mm256_add_epi32(biases, biases);
    __m256i sum_1 = sum_0;
    __m256i sum_2 = sum_0;
    __m256i sum_3 = sum_0;
    __m256i sum_4 = sum_0;

    int reduced_col_count = *nnz_per_row;
    ++nnz_per_row;
    for (int c = 0; c < reduced_col_count; ++c) {
      int col_delta = *col_deltas_bytes++ / sizeof(RhsType);
      for (int k = 0; k < 5; ++k) rhs_ptrs[k] += col_delta;

      // Load the block's 16 int16 weights once ...
      __m128i rhs_64 =
          _mm_loadl_epi64(reinterpret_cast<__m128i const*>(rhs_ptrs[0]));
      __m256i weights =
          _mm256_load_si256(reinterpret_cast<__m256i const*>(weights_ptr));
      __m256i rhs = _mm256_broadcastq_epi64(rhs_64);
      weights_ptr += 16;
      // ... and accumulate the products for each of the 5 rhs vectors.
      sum_0 = _mm256_add_epi32(sum_0, _mm256_madd_epi16(weights, rhs));
      rhs_64 = _mm_loadl_epi64(reinterpret_cast<__m128i const*>(rhs_ptrs[1]));
      rhs = _mm256_broadcastq_epi64(rhs_64);
      sum_1 = _mm256_add_epi32(sum_1, _mm256_madd_epi16(weights, rhs));
      rhs_64 = _mm_loadl_epi64(reinterpret_cast<__m128i const*>(rhs_ptrs[2]));
      rhs = _mm256_broadcastq_epi64(rhs_64);
      sum_2 = _mm256_add_epi32(sum_2, _mm256_madd_epi16(weights, rhs));
      rhs_64 = _mm_loadl_epi64(reinterpret_cast<__m128i const*>(rhs_ptrs[3]));
      rhs = _mm256_broadcastq_epi64(rhs_64);
      sum_3 = _mm256_add_epi32(sum_3, _mm256_madd_epi16(weights, rhs));
      rhs_64 = _mm_loadl_epi64(reinterpret_cast<__m128i const*>(rhs_ptrs[4]));
      rhs = _mm256_broadcastq_epi64(rhs_64);
      sum_4 = _mm256_add_epi32(sum_4, _mm256_madd_epi16(weights, rhs));
    }
    static_assert(
        IsFixed16Type<OutType>::value || IsFixed32Type<OutType>::value,
        "AVX2 kernel only supports fixed16 and fixed32 types");
    if (IsFixed32Type<OutType>::value) {
      Extract4xint32(relu, kShiftAmount, sum_0,
                     reinterpret_cast<int32_t**>(&out_ptrs[0]));
      Extract4xint32(relu, kShiftAmount, sum_1,
                     reinterpret_cast<int32_t**>(&out_ptrs[1]));
      Extract4xint32(relu, kShiftAmount, sum_2,
                     reinterpret_cast<int32_t**>(&out_ptrs[2]));
      Extract4xint32(relu, kShiftAmount, sum_3,
                     reinterpret_cast<int32_t**>(&out_ptrs[3]));
      Extract4xint32(relu, kShiftAmount, sum_4,
                     reinterpret_cast<int32_t**>(&out_ptrs[4]));
    } else {
      Extract4xint16(relu, kShiftAmount, sum_0,
                     reinterpret_cast<int16_t**>(&out_ptrs[0]));
      Extract4xint16(relu, kShiftAmount, sum_1,
                     reinterpret_cast<int16_t**>(&out_ptrs[1]));
      Extract4xint16(relu, kShiftAmount, sum_2,
                     reinterpret_cast<int16_t**>(&out_ptrs[2]));
      Extract4xint16(relu, kShiftAmount, sum_3,
                     reinterpret_cast<int16_t**>(&out_ptrs[3]));
      Extract4xint16(relu, kShiftAmount, sum_4,
                     reinterpret_cast<int16_t**>(&out_ptrs[4]));
    }
  }
}
|
// Adds the recurrent gate contribution at |gate_ptr| (plus, if |SplitGates|,
// a second contribution at |gate_other_ptr|) to |input| and applies the
// fixed-point sigmoid via the supplied lookup table.
template <int InputMantissaBits, int StateMantissaBits, bool SplitGates>
inline __m256i GRUGateSigmoid(const void* gate_ptr, const void* gate_other_ptr,
                              const __m256i& input,
                              const int32_t* sigmoid_table) {
  __m256i gate = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(gate_ptr));
  if (SplitGates) {
    __m256i other =
        _mm256_loadu_si256(reinterpret_cast<const __m256i*>(gate_other_ptr));
    gate = _mm256_add_epi32(gate, other);
  }
  gate = _mm256_add_epi32(gate, input);
  // Apply the fixed-point sigmoid using |sigmoid_table|.
  return csrblocksparse::fixed32_sigmoid_fixed16<InputMantissaBits,
                                                 StateMantissaBits>(
      sigmoid_table, gate);
}
|
// Computes the new GRU state for 8 elements:
//   hbar  = tanh(cell + reset * recurrent_gate)
//   h_new = hbar + update * (h_prev - hbar)
// (equivalently h_new = update * h_prev + (1 - update) * hbar).
// |reset| and |update| are fixed-point sigmoid outputs with
// StateMantissaBits of precision; |gru_h_ptr| points to 8 int16 previous
// state values.
template <int InputMantissaBits, int StateMantissaBits, bool SplitGates = false>
inline __m256i GRUGateState(const __m256i& cell, const __m256i& reset,
                            const __m256i& update,
                            const __m256i& rounding_offset,
                            const void* gate_ptr, const void* gate_other_ptr,
                            const void* gru_h_ptr, const int32_t* tanh_table) {
  // Load the recurrent gate contribution (summing both halves if the gate
  // computation was split).
  __m256i gru = _mm256_loadu_si256(reinterpret_cast<__m256i const*>(gate_ptr));
  if (SplitGates) {
    __m256i other_gru =
        _mm256_loadu_si256(reinterpret_cast<__m256i const*>(gate_other_ptr));
    gru = _mm256_add_epi32(gru, other_gru);
  }
  // Multiply the gate by the reset sigmoid. mul_epi32 only multiplies the
  // even-indexed 32-bit elements (producing 64-bit products), so the odd
  // elements are handled separately after a shuffle.
  __m256i gru_lo = _mm256_mul_epi32(gru, reset);
  gru = _mm256_shuffle_epi32(gru, 0xb1);
  __m256i gru_hi = _mm256_mul_epi32(gru, _mm256_shuffle_epi32(reset, 0xb1));
  // Rescale the 64-bit products: the even products are shifted right, the
  // odd products are shifted left so their relevant 32 bits land in the odd
  // lanes, then the two halves are re-interleaved.
  gru_lo = _mm256_srli_epi64(gru_lo, StateMantissaBits);
  gru_hi = _mm256_slli_epi64(gru_hi, 32 - StateMantissaBits);
  gru = _mm256_blend_epi32(gru_lo, gru_hi, 0xaa);
  gru = _mm256_add_epi32(cell, gru);
  // hbar = tanh(cell + reset * gate), computed via the lookup table.
  __m256i hbar =
      csrblocksparse::fixed32_tanh_fixed16<InputMantissaBits,
                                           StateMantissaBits>(tanh_table, gru);
  // Load the previous state (int16) and form h_prev - hbar.
  gru = _mm256_cvtepi16_epi32(
      _mm_load_si128(reinterpret_cast<__m128i const*>(gru_h_ptr)));
  gru = _mm256_sub_epi32(gru, hbar);
  // Multiply by the update sigmoid with madd_epi16, which multiplies 16-bit
  // pairs; the operands are expected to fit in 16 bits, so this yields the
  // full 32-bit products.
  gru = _mm256_madd_epi16(gru, update);
  // Round, rescale back to the state format, and add hbar:
  // h_new = hbar + update * (h_prev - hbar).
  gru = _mm256_add_epi32(gru, rounding_offset);
  gru = _mm256_srai_epi32(gru, StateMantissaBits);
  return _mm256_add_epi32(gru, hbar);
}
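
// Numeric illustration (values invented for clarity): with
// StateMantissaBits == 12, update == 1024 (0.25 in Q12), a previous state of
// 3277 (~0.8), hbar == 0, and rounding_offset == 1 << 11, the final three
// instructions compute ((1024 * 3277 + 2048) >> 12) + 0 == 819, i.e. ~0.2,
// matching h_new = update * h_prev + (1 - update) * hbar = 0.25 * 0.8.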
|
|
|
// Elementwise sum of two fixed32 vectors over [start, end); |result| may
// alias either input. The pointers are assumed 32-byte aligned and the range
// length a multiple of the SIMD width.
template <typename Type>
typename std::enable_if<IsFixed32Type<Type>::value>::type SumVectors(
    int start, int end, const Type* add1, const Type* add2, Type* result) {
  constexpr int kSIMDWidth = 8;
  for (int i = start; i < end; i += kSIMDWidth) {
    __m256i data1 =
        _mm256_load_si256(reinterpret_cast<__m256i const*>(add1 + i));
    __m256i data2 =
        _mm256_load_si256(reinterpret_cast<__m256i const*>(add2 + i));
    data1 = _mm256_add_epi32(data1, data2);
    _mm256_store_si256(reinterpret_cast<__m256i*>(result + i), data1);
  }
}
|
// Elementwise sum of two fixed16 vectors; as above, but 16 values per
// iteration.
template <typename Type>
typename std::enable_if<IsFixed16Type<Type>::value>::type SumVectors(
    int start, int end, const Type* add1, const Type* add2, Type* result) {
  constexpr int kSIMDWidth = 16;
  for (int i = start; i < end; i += kSIMDWidth) {
    __m256i data1 =
        _mm256_load_si256(reinterpret_cast<__m256i const*>(add1 + i));
    __m256i data2 =
        _mm256_load_si256(reinterpret_cast<__m256i const*>(add2 + i));
    data1 = _mm256_add_epi16(data1, data2);
    _mm256_store_si256(reinterpret_cast<__m256i*>(result + i), data1);
  }
}
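
// Illustrative sketch of SumVectors (the container and fill code are
// assumptions, not part of this header): adding two aligned fixed16 buffers
// of 64 elements.
//
//   csrblocksparse::CacheAlignedVector<csrblocksparse::fixed16<4>> a(64);
//   csrblocksparse::CacheAlignedVector<csrblocksparse::fixed16<4>> b(64);
//   csrblocksparse::CacheAlignedVector<csrblocksparse::fixed16<4>> out(64);
//   // ... fill |a| and |b| ...
//   detail::SumVectors(/*start=*/0, /*end=*/64, a.data(), b.data(),
//                      out.data());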
|
|
|
#endif  // __AVX2__
|
|
|
}  // namespace detail
}  // namespace csrblocksparse
|
|
|
#undef LABEL_COL_LOOP
#undef LABEL_ROW_LOOP
#undef LABEL_SKIP_COL_LOOP
#undef LABEL_TOP_LOOP
|
|
|
#endif  // defined __AVX__
|
|
|
#endif  // LYRA_CODEC_SPARSE_MATMUL_COMPUTE_KERNELS_AVX_H_
|
|