|
#pragma once |
|
|
|
|
|
|
|
|
|
#include <ATen/cpu/vec/intrinsics.h> |
|
#include <ATen/cpu/vec/vec_base.h> |
|
#include <c10/util/irange.h> |
|
#if defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER) |
|
#include <sleef.h> |
|
#endif |
|
|
|
namespace at { |
|
namespace vec { |
|
|
|
inline namespace CPU_CAPABILITY { |
|
|
|
#if defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER) |
|
|
|
template <> class Vectorized<float> { |
|
private: |
|
static constexpr __m512i zero_vec {0, 0, 0, 0, 0, 0, 0, 0}; |
|
public: |
|
__m512 values; |
|
using value_type = float; |
|
using size_type = int; |
|
static constexpr size_type size() { |
|
return 16; |
|
} |
|
Vectorized() {} |
|
Vectorized(__m512 v) : values(v) {} |
|
Vectorized(float val) { |
|
values = _mm512_set1_ps(val); |
|
} |
|
Vectorized(float val1, float val2, float val3, float val4, |
|
float val5, float val6, float val7, float val8, |
|
float val9, float val10, float val11, float val12, |
|
float val13, float val14, float val15, float val16) { |
|
values = _mm512_setr_ps(val1, val2, val3, val4, val5, val6, val7, val8, |
|
val9, val10, val11, val12, val13, val14, val15, val16); |
|
} |
|
operator __m512() const { |
|
return values; |
|
} |
|
template <int64_t mask> |
|
static Vectorized<float> blend(const Vectorized<float>& a, const Vectorized<float>& b) { |
|
return _mm512_mask_blend_ps(mask, a.values, b.values); |
|
} |
|
static Vectorized<float> blendv(const Vectorized<float>& a, const Vectorized<float>& b, |
|
const Vectorized<float>& mask) { |
|
auto all_ones = _mm512_set1_epi32(0xFFFFFFFF); |
|
auto mmask = _mm512_cmp_epi32_mask(_mm512_castps_si512(mask.values), all_ones, _MM_CMPINT_EQ); |
|
return _mm512_mask_blend_ps(mmask, a.values, b.values); |
|
} |
|
template<typename step_t> |
|
static Vectorized<float> arange(float base = 0.f, step_t step = static_cast<step_t>(1)) { |
|
return Vectorized<float>( |
|
base, base + step, base + 2 * step, base + 3 * step, |
|
base + 4 * step, base + 5 * step, base + 6 * step, base + 7 * step, |
|
base + 8 * step, base + 9 * step, base + 10 * step, base + 11 * step, |
|
base + 12 * step, base + 13 * step, base + 14 * step, base + 15 * step); |
|
} |
|
static Vectorized<float> set(const Vectorized<float>& a, const Vectorized<float>& b, |
|
int64_t count = size()) { |
|
switch (count) { |
|
case 0: |
|
return a; |
|
case 1: |
|
return blend<1>(a, b); |
|
case 2: |
|
return blend<3>(a, b); |
|
case 3: |
|
return blend<7>(a, b); |
|
case 4: |
|
return blend<15>(a, b); |
|
case 5: |
|
return blend<31>(a, b); |
|
case 6: |
|
return blend<63>(a, b); |
|
case 7: |
|
return blend<127>(a, b); |
|
case 8: |
|
return blend<255>(a, b); |
|
case 9: |
|
return blend<511>(a, b); |
|
case 10: |
|
return blend<1023>(a, b); |
|
case 11: |
|
return blend<2047>(a, b); |
|
case 12: |
|
return blend<4095>(a, b); |
|
case 13: |
|
return blend<8191>(a, b); |
|
case 14: |
|
return blend<16383>(a, b); |
|
case 15: |
|
return blend<32767>(a, b); |
|
} |
|
return b; |
|
} |
|
static Vectorized<float> loadu(const void* ptr, int64_t count = size()) { |
|
if (count == size()) |
|
return _mm512_loadu_ps(reinterpret_cast<const float*>(ptr)); |
|
__at_align__ float tmp_values[size()]; |
|
|
|
|
|
|
|
for (const auto i : c10::irange(size())) { |
|
tmp_values[i] = 0.0; |
|
} |
|
std::memcpy( |
|
tmp_values, reinterpret_cast<const float*>(ptr), count * sizeof(float)); |
|
return _mm512_loadu_ps(tmp_values); |
|
} |
|
void store(void* ptr, int64_t count = size()) const { |
|
if (count == size()) { |
|
_mm512_storeu_ps(reinterpret_cast<float*>(ptr), values); |
|
} else if (count > 0) { |
|
float tmp_values[size()]; |
|
_mm512_storeu_ps(reinterpret_cast<float*>(tmp_values), values); |
|
std::memcpy(ptr, tmp_values, count * sizeof(float)); |
|
} |
|
} |
|
const float& operator[](int idx) const = delete; |
|
float& operator[](int idx) = delete; |
|
int zero_mask() const { |
|
|
|
__mmask16 cmp = _mm512_cmp_ps_mask(values, _mm512_set1_ps(0.0), _CMP_EQ_OQ); |
|
return static_cast<int32_t>(cmp); |
|
} |
|
Vectorized<float> isnan() const { |
|
auto mask = _mm512_cmp_ps_mask(values, _mm512_set1_ps(0.0), _CMP_UNORD_Q); |
|
return _mm512_castsi512_ps(_mm512_mask_set1_epi32(zero_vec, mask, |
|
0xFFFFFFFF)); |
|
} |
|
Vectorized<float> map(float (*const f)(float)) const { |
|
__at_align__ float tmp[size()]; |
|
store(tmp); |
|
for (const auto i : c10::irange(size())) { |
|
tmp[i] = f(tmp[i]); |
|
} |
|
return loadu(tmp); |
|
} |
|
Vectorized<float> abs() const { |
|
auto mask = _mm512_set1_ps(-0.f); |
|
return _mm512_andnot_ps(mask, values); |
|
} |
|
Vectorized<float> angle() const { |
|
__m512 zero_vec = _mm512_set1_ps(0.f); |
|
const auto nan_vec = _mm512_set1_ps(NAN); |
|
const auto not_nan_mask = _mm512_cmp_ps_mask(values, values, _CMP_EQ_OQ); |
|
const auto not_nan_vec = _mm512_mask_set1_epi32(_mm512_castps_si512(zero_vec), |
|
not_nan_mask, 0xFFFFFFFF); |
|
const auto nan_mask = _mm512_cmp_ps_mask(_mm512_castsi512_ps(not_nan_vec), |
|
zero_vec, _CMP_EQ_OQ); |
|
const auto pi = _mm512_set1_ps(c10::pi<double>); |
|
|
|
const auto neg_mask = _mm512_cmp_ps_mask(values, zero_vec, _CMP_LT_OQ); |
|
auto angle = _mm512_mask_blend_ps(neg_mask, zero_vec, pi); |
|
angle = _mm512_mask_blend_ps(nan_mask, angle, nan_vec); |
|
return angle; |
|
} |
|
Vectorized<float> real() const { |
|
return *this; |
|
} |
|
Vectorized<float> imag() const { |
|
return _mm512_set1_ps(0); |
|
} |
|
Vectorized<float> conj() const { |
|
return *this; |
|
} |
|
Vectorized<float> acos() const { |
|
return Vectorized<float>(Sleef_acosf16_u10(values)); |
|
} |
|
Vectorized<float> asin() const { |
|
return Vectorized<float>(Sleef_asinf16_u10(values)); |
|
} |
|
Vectorized<float> atan() const { |
|
return Vectorized<float>(Sleef_atanf16_u10(values)); |
|
} |
|
Vectorized<float> atan2(const Vectorized<float> &b) const { |
|
return Vectorized<float>(Sleef_atan2f16_u10(values, b)); |
|
} |
|
Vectorized<float> copysign(const Vectorized<float> &sign) const { |
|
return Vectorized<float>(Sleef_copysignf16(values, sign)); |
|
} |
|
Vectorized<float> erf() const { |
|
return Vectorized<float>(Sleef_erff16_u10(values)); |
|
} |
|
Vectorized<float> erfc() const { |
|
return Vectorized<float>(Sleef_erfcf16_u15(values)); |
|
} |
|
Vectorized<float> erfinv() const { |
|
return map(calc_erfinv); |
|
} |
|
Vectorized<float> exp() const { |
|
return Vectorized<float>(Sleef_expf16_u10(values)); |
|
} |
|
Vectorized<float> expm1() const { |
|
return Vectorized<float>(Sleef_expm1f16_u10(values)); |
|
} |
|
Vectorized<float> fmod(const Vectorized<float>& q) const { |
|
return Vectorized<float>(Sleef_fmodf16(values, q)); |
|
} |
|
Vectorized<float> log() const { |
|
return Vectorized<float>(Sleef_logf16_u10(values)); |
|
} |
|
Vectorized<float> log2() const { |
|
return Vectorized<float>(Sleef_log2f16_u10(values)); |
|
} |
|
Vectorized<float> log10() const { |
|
return Vectorized<float>(Sleef_log10f16_u10(values)); |
|
} |
|
Vectorized<float> log1p() const { |
|
return Vectorized<float>(Sleef_log1pf16_u10(values)); |
|
} |
|
Vectorized<float> frac() const; |
|
Vectorized<float> sin() const { |
|
return Vectorized<float>(Sleef_sinf16_u10(values)); |
|
} |
|
Vectorized<float> sinh() const { |
|
return Vectorized<float>(Sleef_sinhf16_u10(values)); |
|
} |
|
Vectorized<float> cos() const { |
|
return Vectorized<float>(Sleef_cosf16_u10(values)); |
|
} |
|
Vectorized<float> cosh() const { |
|
return Vectorized<float>(Sleef_coshf16_u10(values)); |
|
} |
|
Vectorized<float> ceil() const { |
|
return _mm512_ceil_ps(values); |
|
} |
|
Vectorized<float> floor() const { |
|
return _mm512_floor_ps(values); |
|
} |
|
Vectorized<float> hypot(const Vectorized<float> &b) const { |
|
return Vectorized<float>(Sleef_hypotf16_u05(values, b)); |
|
} |
|
Vectorized<float> i0() const { |
|
return map(calc_i0); |
|
} |
|
Vectorized<float> i0e() const { |
|
return map(calc_i0e); |
|
} |
|
Vectorized<float> igamma(const Vectorized<float> &x) const { |
|
__at_align__ float tmp[size()]; |
|
__at_align__ float tmp_x[size()]; |
|
store(tmp); |
|
x.store(tmp_x); |
|
for (const auto i : c10::irange(size())) { |
|
tmp[i] = calc_igamma(tmp[i], tmp_x[i]); |
|
} |
|
return loadu(tmp); |
|
} |
|
Vectorized<float> igammac(const Vectorized<float> &x) const { |
|
__at_align__ float tmp[size()]; |
|
__at_align__ float tmp_x[size()]; |
|
store(tmp); |
|
x.store(tmp_x); |
|
for (const auto i : c10::irange(size())) { |
|
tmp[i] = calc_igammac(tmp[i], tmp_x[i]); |
|
} |
|
return loadu(tmp); |
|
} |
|
Vectorized<float> neg() const { |
|
return _mm512_xor_ps(_mm512_set1_ps(-0.f), values); |
|
} |
|
Vectorized<float> nextafter(const Vectorized<float> &b) const { |
|
return Vectorized<float>(Sleef_nextafterf16(values, b)); |
|
} |
|
Vectorized<float> round() const { |
|
return _mm512_roundscale_ps(values, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); |
|
} |
|
Vectorized<float> tan() const { |
|
return Vectorized<float>(Sleef_tanf16_u10(values)); |
|
} |
|
Vectorized<float> tanh() const { |
|
return Vectorized<float>(Sleef_tanhf16_u10(values)); |
|
} |
|
Vectorized<float> trunc() const { |
|
return _mm512_roundscale_ps(values, (_MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC)); |
|
} |
|
Vectorized<float> lgamma() const { |
|
return Vectorized<float>(Sleef_lgammaf16_u10(values)); |
|
} |
|
Vectorized<float> sqrt() const { |
|
return _mm512_sqrt_ps(values); |
|
} |
|
Vectorized<float> reciprocal() const { |
|
return _mm512_div_ps(_mm512_set1_ps(1), values); |
|
} |
|
Vectorized<float> rsqrt() const { |
|
return _mm512_div_ps(_mm512_set1_ps(1), _mm512_sqrt_ps(values)); |
|
} |
|
Vectorized<float> pow(const Vectorized<float> &b) const { |
|
return Vectorized<float>(Sleef_powf16_u10(values, b)); |
|
} |
|
|
|
|
|
|
|
Vectorized<float> operator==(const Vectorized<float>& other) const { |
|
auto mask = _mm512_cmp_ps_mask(values, other.values, _CMP_EQ_OQ); |
|
return _mm512_castsi512_ps(_mm512_mask_set1_epi32(zero_vec, mask, |
|
0xFFFFFFFF)); |
|
} |
|
|
|
Vectorized<float> operator!=(const Vectorized<float>& other) const { |
|
auto mask = _mm512_cmp_ps_mask(values, other.values, _CMP_NEQ_OQ); |
|
return _mm512_castsi512_ps(_mm512_mask_set1_epi32(zero_vec, mask, |
|
0xFFFFFFFF)); |
|
} |
|
|
|
Vectorized<float> operator<(const Vectorized<float>& other) const { |
|
auto mask = _mm512_cmp_ps_mask(values, other.values, _CMP_LT_OQ); |
|
return _mm512_castsi512_ps(_mm512_mask_set1_epi32(zero_vec, mask, |
|
0xFFFFFFFF)); |
|
} |
|
|
|
Vectorized<float> operator<=(const Vectorized<float>& other) const { |
|
auto mask = _mm512_cmp_ps_mask(values, other.values, _CMP_LE_OQ); |
|
return _mm512_castsi512_ps(_mm512_mask_set1_epi32(zero_vec, mask, |
|
0xFFFFFFFF)); |
|
} |
|
|
|
Vectorized<float> operator>(const Vectorized<float>& other) const { |
|
auto mask = _mm512_cmp_ps_mask(values, other.values, _CMP_GT_OQ); |
|
return _mm512_castsi512_ps(_mm512_mask_set1_epi32(zero_vec, mask, |
|
0xFFFFFFFF)); |
|
} |
|
|
|
Vectorized<float> operator>=(const Vectorized<float>& other) const { |
|
auto mask = _mm512_cmp_ps_mask(values, other.values, _CMP_GE_OQ); |
|
return _mm512_castsi512_ps(_mm512_mask_set1_epi32(zero_vec, mask, |
|
0xFFFFFFFF)); |
|
} |
|
|
|
Vectorized<float> eq(const Vectorized<float>& other) const; |
|
Vectorized<float> ne(const Vectorized<float>& other) const; |
|
Vectorized<float> gt(const Vectorized<float>& other) const; |
|
Vectorized<float> ge(const Vectorized<float>& other) const; |
|
Vectorized<float> lt(const Vectorized<float>& other) const; |
|
Vectorized<float> le(const Vectorized<float>& other) const; |
|
}; |
|
|
|
template <> |
|
Vectorized<float> inline operator+(const Vectorized<float>& a, const Vectorized<float>& b) { |
|
return _mm512_add_ps(a, b); |
|
} |
|
|
|
template <> |
|
Vectorized<float> inline operator-(const Vectorized<float>& a, const Vectorized<float>& b) { |
|
return _mm512_sub_ps(a, b); |
|
} |
|
|
|
template <> |
|
Vectorized<float> inline operator*(const Vectorized<float>& a, const Vectorized<float>& b) { |
|
return _mm512_mul_ps(a, b); |
|
} |
|
|
|
template <> |
|
Vectorized<float> inline operator/(const Vectorized<float>& a, const Vectorized<float>& b) { |
|
return _mm512_div_ps(a, b); |
|
} |
|
|
|
|
|
// Fractional part, x - trunc(x). For finite x the result is zero or has the
// same sign as x. Defined out of line because it relies on operator-.
inline Vectorized<float> Vectorized<float>::frac() const {
  const Vectorized<float> integral_part = trunc();
  return *this - integral_part;
}
|
|
|
|
|
|
|
template <> |
|
Vectorized<float> inline maximum(const Vectorized<float>& a, const Vectorized<float>& b) { |
|
auto zero_vec = _mm512_set1_epi32(0); |
|
auto max = _mm512_max_ps(a, b); |
|
auto isnan_mask = _mm512_cmp_ps_mask(a, b, _CMP_UNORD_Q); |
|
auto isnan = _mm512_castsi512_ps(_mm512_mask_set1_epi32(zero_vec, isnan_mask, |
|
0xFFFFFFFF)); |
|
|
|
return _mm512_or_ps(max, isnan); |
|
} |
|
|
|
|
|
|
|
template <> |
|
Vectorized<float> inline minimum(const Vectorized<float>& a, const Vectorized<float>& b) { |
|
auto zero_vec = _mm512_set1_epi32(0); |
|
auto min = _mm512_min_ps(a, b); |
|
auto isnan_mask = _mm512_cmp_ps_mask(a, b, _CMP_UNORD_Q); |
|
auto isnan = _mm512_castsi512_ps(_mm512_mask_set1_epi32(zero_vec, isnan_mask, |
|
0xFFFFFFFF)); |
|
|
|
return _mm512_or_ps(min, isnan); |
|
} |
|
|
|
template <> |
|
Vectorized<float> inline clamp(const Vectorized<float>& a, const Vectorized<float>& min, const Vectorized<float>& max) { |
|
return _mm512_min_ps(max, _mm512_max_ps(min, a)); |
|
} |
|
|
|
template <> |
|
Vectorized<float> inline clamp_max(const Vectorized<float>& a, const Vectorized<float>& max) { |
|
return _mm512_min_ps(max, a); |
|
} |
|
|
|
template <> |
|
Vectorized<float> inline clamp_min(const Vectorized<float>& a, const Vectorized<float>& min) { |
|
return _mm512_max_ps(min, a); |
|
} |
|
|
|
template <> |
|
Vectorized<float> inline operator&(const Vectorized<float>& a, const Vectorized<float>& b) { |
|
return _mm512_and_ps(a, b); |
|
} |
|
|
|
template <> |
|
Vectorized<float> inline operator|(const Vectorized<float>& a, const Vectorized<float>& b) { |
|
return _mm512_or_ps(a, b); |
|
} |
|
|
|
template <> |
|
Vectorized<float> inline operator^(const Vectorized<float>& a, const Vectorized<float>& b) { |
|
return _mm512_xor_ps(a, b); |
|
} |
|
|
|
// Numeric form of operator==: 1.0f in lanes where operator== is true, 0.0f
// elsewhere. AND-ing the bit pattern of 1.0f with the all-ones/all-zeros lane
// mask selects between the two values without a blend.
inline Vectorized<float> Vectorized<float>::eq(const Vectorized<float>& other) const {
  const Vectorized<float> one(1.0f);
  return one & (*this == other);
}
|
|
|
// Numeric form of operator!=: 1.0f in lanes where operator!= is true, 0.0f
// elsewhere (the lane mask selects between the bit patterns of 1.0f and 0).
inline Vectorized<float> Vectorized<float>::ne(const Vectorized<float>& other) const {
  const Vectorized<float> one(1.0f);
  return one & (*this != other);
}
|
|
|
// Numeric form of operator>: 1.0f in lanes where operator> is true, 0.0f
// elsewhere (the lane mask selects between the bit patterns of 1.0f and 0).
inline Vectorized<float> Vectorized<float>::gt(const Vectorized<float>& other) const {
  const Vectorized<float> one(1.0f);
  return one & (*this > other);
}
|
|
|
// Numeric form of operator>=: 1.0f in lanes where operator>= is true, 0.0f
// elsewhere (the lane mask selects between the bit patterns of 1.0f and 0).
inline Vectorized<float> Vectorized<float>::ge(const Vectorized<float>& other) const {
  const Vectorized<float> one(1.0f);
  return one & (*this >= other);
}
|
|
|
// Numeric form of operator<: 1.0f in lanes where operator< is true, 0.0f
// elsewhere (the lane mask selects between the bit patterns of 1.0f and 0).
inline Vectorized<float> Vectorized<float>::lt(const Vectorized<float>& other) const {
  const Vectorized<float> one(1.0f);
  return one & (*this < other);
}
|
|
|
// Numeric form of operator<=: 1.0f in lanes where operator<= is true, 0.0f
// elsewhere (the lane mask selects between the bit patterns of 1.0f and 0).
inline Vectorized<float> Vectorized<float>::le(const Vectorized<float>& other) const {
  const Vectorized<float> one(1.0f);
  return one & (*this <= other);
}
|
|
|
template <> |
|
inline void convert(const float* src, float* dst, int64_t n) { |
|
int64_t i; |
|
#pragma unroll |
|
for (i = 0; i <= (n - Vectorized<float>::size()); i += Vectorized<float>::size()) { |
|
_mm512_storeu_ps(dst + i, _mm512_loadu_ps(src + i)); |
|
} |
|
#pragma unroll |
|
for (; i < n; i++) { |
|
dst[i] = src[i]; |
|
} |
|
} |
|
|
|
template <> |
|
Vectorized<float> inline fmadd(const Vectorized<float>& a, const Vectorized<float>& b, const Vectorized<float>& c) { |
|
return _mm512_fmadd_ps(a, b, c); |
|
} |
|
|
|
#endif |
|
|
|
}}} |
|
|