|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#ifndef LYRA_CODEC_SPARSE_MATMUL_NUMERICS_FLOAT16_TYPES_H_ |
|
#define LYRA_CODEC_SPARSE_MATMUL_NUMERICS_FLOAT16_TYPES_H_ |
|
|
|
#include <cstdint> |
|
#include <cstring> |
|
#include <type_traits> |
|
|
|
namespace csrblocksparse { |
|
|
|
|
|
|
|
class fp16 { |
|
public: |
|
fp16() = default; |
|
explicit fp16(float x) : val_(float_to_fp16(x)) {} |
|
explicit fp16(uint16_t x) : val_(x) {} |
|
static constexpr int kMantissaBits = 11; |
|
|
|
explicit operator float() const { return fp16_to_float(val_); } |
|
|
|
private: |
|
inline float fp16_to_float(uint16_t as_int) const { |
|
#if defined __aarch64__ |
|
float x; |
|
float* x_ptr = &x; |
|
asm volatile( |
|
"dup v0.8h, %w[as_int]\n" |
|
"fcvtl v1.4s, v0.4h\n " |
|
"st1 {v1.s}[0], [%[x_ptr]]\n" |
|
: |
|
: |
|
[x_ptr] "r"(x_ptr), |
|
[as_int] "r"(as_int) |
|
: |
|
"cc", "memory", "v0", "v1"); |
|
return x; |
|
#else |
|
unsigned int sign_bit = (as_int & 0x8000) << 16; |
|
unsigned int exponent = as_int & 0x7c00; |
|
|
|
unsigned int mantissa; |
|
if (exponent == 0) |
|
mantissa = 0; |
|
else |
|
mantissa = ((as_int & 0x7fff) << 13) + 0x38000000; |
|
mantissa |= sign_bit; |
|
|
|
float x; |
|
memcpy(&x, &mantissa, sizeof(int)); |
|
return x; |
|
#endif |
|
} |
|
|
|
inline uint16_t float_to_fp16(float x) const { |
|
#if defined __aarch64__ |
|
uint16_t as_int; |
|
uint16_t* as_int_ptr = &as_int; |
|
asm volatile( |
|
"dup v0.4s, %w[x]\n" |
|
"fcvtn v1.4h, v0.4s\n" |
|
"st1 {v1.h}[0], [%[as_int_ptr]]\n" |
|
: |
|
: |
|
[as_int_ptr] "r"(as_int_ptr), |
|
[x] "r"(x) |
|
: |
|
"cc", "memory", "v0", "v1"); |
|
return as_int; |
|
#else |
|
unsigned int x_int; |
|
memcpy(&x_int, &x, sizeof(int)); |
|
|
|
unsigned int sign_bit = (x_int & 0x80000000) >> 16; |
|
unsigned int exponent = x_int & 0x7f800000; |
|
|
|
unsigned int mantissa; |
|
if (exponent < 0x38800000) { |
|
mantissa = 0; |
|
} else if (exponent > 0x8e000000) { |
|
mantissa = 0x7bff; |
|
} else { |
|
mantissa = ((x_int & 0x7fffffff) >> 13) - 0x1c000; |
|
} |
|
|
|
mantissa |= sign_bit; |
|
|
|
return static_cast<uint16_t>(mantissa & 0xFFFF); |
|
#endif |
|
} |
|
|
|
uint16_t val_; |
|
}; |
|
|
|
|
|
|
|
class bfloat16 { |
|
public: |
|
bfloat16() = default; |
|
explicit bfloat16(float x) : val_(float_to_bfloat16(x)) {} |
|
explicit bfloat16(uint16_t x) : val_(x) {} |
|
static constexpr int kMantissaBits = 7; |
|
|
|
explicit operator float() const { return bfloat16_to_float(val_); } |
|
|
|
private: |
|
inline uint16_t float_to_bfloat16(float x) const { |
|
uint32_t as_int; |
|
std::memcpy(&as_int, &x, sizeof(float)); |
|
return as_int >> 16; |
|
} |
|
|
|
inline float bfloat16_to_float(uint32_t as_int) const { |
|
as_int <<= 16; |
|
float x; |
|
std::memcpy(&x, &as_int, sizeof(float)); |
|
return x; |
|
} |
|
|
|
uint16_t val_; |
|
}; |
|
|
|
template <typename T> |
|
struct IsCustomFloatType |
|
: std::integral_constant<bool, std::is_same<T, bfloat16>::value || |
|
std::is_same<T, fp16>::value> {}; |
|
template <typename T> |
|
struct IsAnyFloatType |
|
: std::integral_constant<bool, std::is_floating_point<T>::value || |
|
IsCustomFloatType<T>::value> {}; |
|
|
|
} |
|
|
|
#endif |
|
|