|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#ifndef LYRA_CODEC_SPARSE_MATMUL_COMPUTE_KERNELS_GENERIC_H_ |
|
#define LYRA_CODEC_SPARSE_MATMUL_COMPUTE_KERNELS_GENERIC_H_ |
|
|
|
#include <algorithm> |
|
#include <type_traits> |
|
|
|
#include "sparse_matmul/numerics/fixed_types.h" |
|
#include "sparse_matmul/numerics/float16_types.h" |
|
#include "sparse_matmul/numerics/type_utils.h" |
|
|
|
|
|
|
|
#if defined(__aarch64__)
#include "sparse_matmul/compute/kernels_arm.h"
#elif defined(__AVX__)
#include "sparse_matmul/compute/kernels_avx.h"
#else
// Fallback trait definitions for targets with neither an aarch64 nor an AVX
// kernel header. The kernels in this file are gated with std::enable_if on
// these traits in every configuration, so on aarch64/AVX the included header
// above must provide them instead. Here no specialised kernel exists, so
// every trait derives from std::true_type for every type combination and the
// generic kernel is always selected.

// Gate for the generic 1x1-blocked sparse matrix * vector kernel.
template <class WeightType, class RhsType, class OutType>
struct ShouldEnableGenericSpMV_1x1 : public std::true_type {};

// Gate for the generic 1x1-blocked sparse matrix * 5-column matrix kernel.
template <class WeightType, class RhsType, class OutType>
struct ShouldEnableGenericSpMM5_1x1 : public std::true_type {};

// Gate for the generic 4x4-blocked sparse matrix * vector kernel.
template <class WeightType, class RhsType, class OutType>
struct ShouldEnableGenericSpMV_4x4 : public std::true_type {};

// Gate for the generic 4x4-blocked sparse matrix * 5-column matrix kernel.
template <class WeightType, class RhsType, class OutType>
struct ShouldEnableGenericSpMM5_4x4 : public std::true_type {};

// Gate for the generic element-wise vector addition kernel.
template <class Type>
struct ShouldEnableGenericAdd : public std::true_type {};
#endif  // defined(__aarch64__) / defined(__AVX__)
|
|
|
namespace csrblocksparse { |
|
namespace detail { |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
template <typename WeightType, typename RhsType, typename OutType> |
|
typename std::enable_if< |
|
ShouldEnableGenericSpMV_4x4<WeightType, RhsType, OutType>::value>::type |
|
SpMV_4x4(const WeightType* weights_ptr, const int16_t* col_deltas_bytes, |
|
const int32_t* nnz_per_row, const RhsType* rhs_ptr, |
|
const typename TypeOfProduct<WeightType, RhsType>::type* bias_ptr, |
|
OutType* out_ptr, int64_t assigned_rows, |
|
int64_t rows , |
|
int64_t cols , int relu) { |
|
for (int reduced_row = 0; reduced_row < assigned_rows; ++reduced_row) { |
|
float accumulators[4]; |
|
|
|
for (int i = 0; i < 4; ++i) |
|
accumulators[i] = 4.f * static_cast<float>(*bias_ptr++); |
|
|
|
int reduced_col_count = *nnz_per_row++; |
|
for (int c = 0; c < reduced_col_count; ++c) { |
|
int col_delta = *col_deltas_bytes++ / sizeof(RhsType); |
|
rhs_ptr += col_delta; |
|
|
|
|
|
for (int i = 0; i < 4; ++i) { |
|
for (int j = 0; j < 4; ++j) { |
|
accumulators[i] += static_cast<float>(*weights_ptr++) * |
|
static_cast<float>(rhs_ptr[j]); |
|
} |
|
} |
|
} |
|
|
|
for (int i = 0; i < 4; ++i) |
|
*out_ptr++ = static_cast<OutType>(relu ? std::max(accumulators[i], 0.f) |
|
: accumulators[i]); |
|
} |
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
template <typename WeightType, typename RhsType, typename OutType> |
|
typename std::enable_if< |
|
ShouldEnableGenericSpMM5_4x4<WeightType, RhsType, OutType>::value>::type |
|
SpMM5_4x4(const WeightType* weights_ptr, const int16_t* col_deltas_bytes, |
|
const int32_t* nnz_per_row, const RhsType* rhs_ptr, |
|
const typename TypeOfProduct<WeightType, RhsType>::type* bias_ptr, |
|
OutType* out_ptr, int64_t assigned_rows, int64_t rows, int64_t cols, |
|
int relu) { |
|
const RhsType* rhs_ptrs[5]; |
|
for (int i = 0; i < 5; ++i) rhs_ptrs[i] = rhs_ptr + i * cols; |
|
|
|
OutType* out_ptrs[5]; |
|
for (int i = 0; i < 5; ++i) out_ptrs[i] = out_ptr + i * rows; |
|
|
|
for (int reduced_row = 0; reduced_row < assigned_rows; ++reduced_row) { |
|
float accumulators[4][5]; |
|
|
|
for (int i = 0; i < 4; ++i) { |
|
for (int k = 0; k < 5; ++k) { |
|
accumulators[i][k] = 4.f * static_cast<float>(*bias_ptr); |
|
} |
|
++bias_ptr; |
|
} |
|
|
|
int reduced_col_count = *nnz_per_row++; |
|
for (int c = 0; c < reduced_col_count; ++c) { |
|
int col_delta = *col_deltas_bytes++ / sizeof(RhsType); |
|
for (int k = 0; k < 5; ++k) rhs_ptrs[k] += col_delta; |
|
|
|
|
|
for (int i = 0; i < 4; ++i) { |
|
for (int j = 0; j < 4; ++j) { |
|
for (int k = 0; k < 5; ++k) { |
|
accumulators[i][k] += static_cast<float>(*weights_ptr) * |
|
static_cast<float>(rhs_ptrs[k][j]); |
|
} |
|
weights_ptr++; |
|
} |
|
} |
|
} |
|
|
|
for (int k = 0; k < 5; ++k) { |
|
for (int i = 0; i < 4; ++i) { |
|
out_ptrs[k][0] = static_cast<OutType>( |
|
relu ? std::max(accumulators[i][k], 0.f) : accumulators[i][k]); |
|
out_ptrs[k]++; |
|
} |
|
} |
|
} |
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
template <typename WeightType, typename RhsType, typename OutType> |
|
typename std::enable_if< |
|
ShouldEnableGenericSpMV_1x1<WeightType, RhsType, OutType>::value>::type |
|
SpMV_1x1(const WeightType* weights_ptr, const int16_t* col_deltas_bytes, |
|
const int32_t* nnz_per_row, const RhsType* rhs_ptr, |
|
const typename TypeOfProduct<WeightType, RhsType>::type* bias_ptr, |
|
OutType* out_ptr, int64_t assigned_rows, |
|
int64_t rows , |
|
int64_t cols , int relu) { |
|
for (int row = 0; row < assigned_rows; ++row) { |
|
|
|
float accumulator = 4.f * static_cast<float>(*bias_ptr++); |
|
|
|
int col_count = *nnz_per_row++; |
|
for (int c = 0; c < col_count; ++c) { |
|
int col_delta = *col_deltas_bytes++ / sizeof(RhsType); |
|
rhs_ptr += col_delta; |
|
|
|
accumulator += |
|
static_cast<float>(*weights_ptr++) * static_cast<float>(*rhs_ptr); |
|
} |
|
|
|
*out_ptr++ = |
|
static_cast<OutType>(relu ? std::max(accumulator, 0.f) : accumulator); |
|
} |
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
template <typename WeightType, typename RhsType, typename OutType> |
|
typename std::enable_if< |
|
ShouldEnableGenericSpMM5_1x1<WeightType, RhsType, OutType>::value>::type |
|
SpMM5_1x1(const WeightType* weights_ptr, const int16_t* col_deltas_bytes, |
|
const int32_t* nnz_per_row, const RhsType* rhs_ptr, |
|
const typename TypeOfProduct<WeightType, RhsType>::type* bias_ptr, |
|
OutType* out_ptr, int64_t assigned_rows, int64_t rows, int64_t cols, |
|
int relu) { |
|
const RhsType* rhs_ptrs[5]; |
|
for (int i = 0; i < 5; ++i) rhs_ptrs[i] = rhs_ptr + i * cols; |
|
|
|
OutType* out_ptrs[5]; |
|
for (int i = 0; i < 5; ++i) out_ptrs[i] = out_ptr + i * rows; |
|
|
|
for (int row = 0; row < assigned_rows; ++row) { |
|
|
|
float accumulator[5]; |
|
for (int i = 0; i < 5; ++i) |
|
accumulator[i] = 4.f * static_cast<float>(*bias_ptr); |
|
|
|
++bias_ptr; |
|
|
|
int col_count = *nnz_per_row++; |
|
for (int c = 0; c < col_count; ++c) { |
|
int col_delta = *col_deltas_bytes++ / sizeof(RhsType); |
|
for (int i = 0; i < 5; ++i) { |
|
rhs_ptrs[i] += col_delta; |
|
accumulator[i] += static_cast<float>(*weights_ptr) * |
|
static_cast<float>(rhs_ptrs[i][0]); |
|
} |
|
weights_ptr++; |
|
} |
|
|
|
for (int i = 0; i < 5; ++i) { |
|
out_ptrs[i][0] = static_cast<OutType>(relu ? std::max(accumulator[i], 0.f) |
|
: accumulator[i]); |
|
out_ptrs[i]++; |
|
} |
|
} |
|
} |
|
|
|
template <typename Type> |
|
typename std::enable_if<ShouldEnableGenericAdd<Type>::value>::type SumVectors( |
|
int start, int end, const Type* add1, const Type* add2, Type* result) { |
|
LOG_FIRST_N(WARNING, 1) << "SumVectors: using generic kernel!"; |
|
for (int i = start; i < end; ++i) { |
|
Type sum = static_cast<Type>(static_cast<float>(add1[i]) + |
|
static_cast<float>(add2[i])); |
|
result[i] = sum; |
|
} |
|
} |
|
|
|
} |
|
} |
|
|
|
#undef LABEL_COL_LOOP |
|
#undef LABEL_ROW_LOOP |
|
#undef LABEL_SKIP_COL_LOOP |
|
#undef LABEL_TOP_LOOP |
|
|
|
#endif |
|
|