|
|
|
|
|
|
|
#pragma once |
|
|
|
#include <stdint.h> |
|
#include <cmath> |
|
#include <cstring> |
|
#include <limits> |
|
|
|
namespace onnxruntime_float16 { |
|
|
|
namespace detail {

/// Host byte order, shaped like C++20 `std::endian` so the conversion code
/// below can pick a byte layout without requiring C++20.
enum class endian {
#ifdef _WIN32
  // The Windows branch hard-codes a little-endian host.
  little = 0,
  big = 1,
  native = little,
#elif defined(__GNUC__) || defined(__clang__)
  // GCC and Clang predefine these byte-order macros.
  little = __ORDER_LITTLE_ENDIAN__,
  big = __ORDER_BIG_ENDIAN__,
  native = __BYTE_ORDER__,
#else
#error onnxruntime_float16::detail::endian is not implemented in this environment.
#endif
};

// Any other layout (e.g. a mixed-endian host) is rejected at compile time,
// because the memcpy-based helpers only handle these two cases.
static_assert(
    endian::native == endian::little || endian::native == endian::big,
    "Only little-endian or big-endian native byte orders are supported.");

}  // namespace detail
|
|
|
|
|
|
|
|
|
/// Shared implementation for IEEE 754 binary16 (half precision) wrapper types.
///
/// CRTP base: `Derived` is the concrete wrapper and must provide
/// `static Derived FromBits(uint16_t)`, which Abs() and Negate() call.
/// The raw binary16 bit pattern is stored in the public member `val`.
template <class Derived>
struct Float16Impl {
 protected:
  /// Converts a float to a binary16 bit pattern.
  /// Values too large for binary16 become infinity, NaN becomes the quiet NaN
  /// 0x7E00 (carrying the input's sign bit), and in-range normal results are
  /// rounded to nearest, ties to even.
  constexpr static uint16_t ToUint16Impl(float v) noexcept;

  /// Widens the stored binary16 bits to float. Every binary16 value is exactly
  /// representable as a float, so finite values convert without rounding.
  float ToFloatImpl() const noexcept;

  /// Returns `val` with the sign bit cleared (magnitude bits only).
  uint16_t AbsImpl() const noexcept {
    return static_cast<uint16_t>(val & ~kSignMask);
  }

  /// Returns `val` with the sign bit flipped. NaN bit patterns are returned
  /// unchanged.
  uint16_t NegateImpl() const noexcept {
    return IsNaN() ? val : static_cast<uint16_t>(val ^ kSignMask);
  }

 public:
  static constexpr uint16_t kSignMask = 0x8000U;
  static constexpr uint16_t kBiasedExponentMask = 0x7C00U;
  static constexpr uint16_t kPositiveInfinityBits = 0x7C00U;
  static constexpr uint16_t kNegativeInfinityBits = 0xFC00U;
  static constexpr uint16_t kPositiveQNaNBits = 0x7E00U;
  static constexpr uint16_t kNegativeQNaNBits = 0xFE00U;
  // NOTE(review): 0x4170 decodes to 2.71875 (about e). It is neither the
  // binary16 machine epsilon 2^-10 (0x1400) nor the smallest subnormal
  // (0x0001). The value is left unchanged for compatibility; confirm the
  // intended meaning before relying on it.
  static constexpr uint16_t kEpsilonBits = 0x4170U;
  static constexpr uint16_t kMinValueBits = 0xFBFFU;  // -65504, lowest finite value
  static constexpr uint16_t kMaxValueBits = 0x7BFFU;  // 65504, largest finite value
  static constexpr uint16_t kOneBits = 0x3C00U;
  static constexpr uint16_t kMinusOneBits = 0xBC00U;

  /// Raw binary16 bit pattern.
  uint16_t val{0};

  Float16Impl() = default;

  /// True when the sign bit is set. This includes -0 and NaNs with the sign
  /// bit set.
  bool IsNegative() const noexcept {
    return static_cast<int16_t>(val) < 0;
  }

  /// True for any NaN: exponent all ones and a nonzero mantissa.
  bool IsNaN() const noexcept {
    return AbsImpl() > kPositiveInfinityBits;
  }

  /// True for zeros, subnormals and normals (not infinity, not NaN).
  bool IsFinite() const noexcept {
    return AbsImpl() < kPositiveInfinityBits;
  }

  /// True only for +infinity.
  bool IsPositiveInfinity() const noexcept {
    return val == kPositiveInfinityBits;
  }

  /// True only for -infinity.
  bool IsNegativeInfinity() const noexcept {
    return val == kNegativeInfinityBits;
  }

  /// True for +infinity or -infinity.
  bool IsInfinity() const noexcept {
    return AbsImpl() == kPositiveInfinityBits;
  }

  /// True for +0, -0 or any NaN.
  bool IsNaNOrZero() const noexcept {
    auto abs = AbsImpl();
    return (abs == 0 || abs > kPositiveInfinityBits);
  }

  /// True for finite, nonzero values whose biased exponent is nonzero.
  bool IsNormal() const noexcept {
    auto abs = AbsImpl();
    return (abs < kPositiveInfinityBits)        // finite
           && (abs != 0)                        // not zero
           && ((abs & kBiasedExponentMask) != 0);  // exponent field set
  }

  /// True for nonzero values whose biased exponent field is zero.
  bool IsSubnormal() const noexcept {
    auto abs = AbsImpl();
    return (abs < kPositiveInfinityBits)        // finite
           && (abs != 0)                        // not zero
           && ((abs & kBiasedExponentMask) == 0);  // exponent field clear
  }

  /// Magnitude of the value (sign bit cleared), built via Derived::FromBits.
  Derived Abs() const noexcept { return Derived::FromBits(AbsImpl()); }

  /// Value with the sign flipped (NaN left unchanged), built via
  /// Derived::FromBits.
  Derived Negate() const noexcept { return Derived::FromBits(NegateImpl()); }

  /// True when both operands are zero of either sign. The sign bit is ignored,
  /// so +0 and -0 are both treated as zero.
  static bool AreZero(const Float16Impl& lhs, const Float16Impl& rhs) noexcept {
    return static_cast<uint16_t>((lhs.val | rhs.val) & ~kSignMask) == 0;
  }

  /// IEEE 754 equality.
  bool operator==(const Float16Impl& rhs) const noexcept {
    if (IsNaN() || rhs.IsNaN()) {
      // IEEE 754: NaN compares unequal to everything, including itself.
      return false;
    }
    // IEEE 754: +0 and -0 compare equal even though their bit patterns differ.
    // This keeps == consistent with operator<, which treats the zeros as equal.
    return val == rhs.val || AreZero(*this, rhs);
  }

  /// IEEE 754 inequality (true when either operand is NaN).
  bool operator!=(const Float16Impl& rhs) const noexcept { return !(*this == rhs); }

  /// IEEE 754 less-than. Returns false if either operand is NaN.
  bool operator<(const Float16Impl& rhs) const noexcept {
    if (IsNaN() || rhs.IsNaN()) {
      // IEEE 754: comparisons involving NaN are unordered and yield false.
      return false;
    }

    const bool left_is_negative = IsNegative();
    if (left_is_negative != rhs.IsNegative()) {
      // Opposite signs: the negative operand is smaller, unless both are zeros
      // (-0 < +0 is false).
      return left_is_negative && !AreZero(*this, rhs);
    }
    // Same sign: sign-magnitude bit patterns order like unsigned integers for
    // positives and in reverse for negatives.
    return (val != rhs.val) && ((val < rhs.val) ^ left_is_negative);
  }
};
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
namespace detail {

/// Raw 32-bit view of a float, used by the Float16Impl conversion routines.
///
/// `u` must stay the first member: the conversion code brace-initializes
/// constants such as `{255 << 23}`, and that initializes the first member.
/// Reading the member that was not last written is union type punning, which
/// this code relies on the compiler to support (ISO C++ does not guarantee it).
union float32_bits {
  unsigned int u;
  float f;
};

// Every bit trick built on this union (shifts by 23, masks like 0x80000000u)
// assumes the integer view covers exactly the float's storage.
static_assert(sizeof(unsigned int) == sizeof(float),
              "float32_bits requires unsigned int and float to have the same size");

}  // namespace detail
|
|
|
template <class Derived> |
|
inline constexpr uint16_t Float16Impl<Derived>::ToUint16Impl(float v) noexcept { |
|
detail::float32_bits f{}; |
|
f.f = v; |
|
|
|
constexpr detail::float32_bits f32infty = {255 << 23}; |
|
constexpr detail::float32_bits f16max = {(127 + 16) << 23}; |
|
constexpr detail::float32_bits denorm_magic = {((127 - 15) + (23 - 10) + 1) << 23}; |
|
constexpr unsigned int sign_mask = 0x80000000u; |
|
uint16_t val = static_cast<uint16_t>(0x0u); |
|
|
|
unsigned int sign = f.u & sign_mask; |
|
f.u ^= sign; |
|
|
|
|
|
|
|
|
|
|
|
|
|
if (f.u >= f16max.u) { |
|
val = (f.u > f32infty.u) ? 0x7e00 : 0x7c00; |
|
} else { |
|
if (f.u < (113 << 23)) { |
|
|
|
|
|
|
|
f.f += denorm_magic.f; |
|
|
|
|
|
val = static_cast<uint16_t>(f.u - denorm_magic.u); |
|
} else { |
|
unsigned int mant_odd = (f.u >> 13) & 1; |
|
|
|
|
|
|
|
|
|
f.u += 0xc8000fffU; |
|
|
|
f.u += mant_odd; |
|
|
|
val = static_cast<uint16_t>(f.u >> 13); |
|
} |
|
} |
|
|
|
val |= static_cast<uint16_t>(sign >> 16); |
|
return val; |
|
} |
|
|
|
template <class Derived> |
|
inline float Float16Impl<Derived>::ToFloatImpl() const noexcept { |
|
constexpr detail::float32_bits magic = {113 << 23}; |
|
constexpr unsigned int shifted_exp = 0x7c00 << 13; |
|
detail::float32_bits o{}; |
|
|
|
o.u = (val & 0x7fff) << 13; |
|
unsigned int exp = shifted_exp & o.u; |
|
o.u += (127 - 15) << 23; |
|
|
|
|
|
if (exp == shifted_exp) { |
|
o.u += (128 - 16) << 23; |
|
} else if (exp == 0) { |
|
o.u += 1 << 23; |
|
o.f -= magic.f; |
|
} |
|
|
|
|
|
|
|
#if (defined _MSC_VER) && (defined _M_ARM || defined _M_ARM64 || defined _M_ARM64EC) |
|
if (IsNegative()) { |
|
return -o.f; |
|
} |
|
#else |
|
|
|
o.u |= (val & 0x8000U) << 16U; |
|
#endif |
|
return o.f; |
|
} |
|
|
|
|
|
/// Shared implementation for bfloat16 wrapper types (1 sign bit, 8 exponent
/// bits, 7 mantissa bits, i.e. the upper half of an IEEE binary32).
///
/// CRTP base: `Derived` must supply `static Derived FromBits(uint16_t)`, which
/// Abs() and Negate() use. The raw bit pattern lives in the public `val`.
template <class Derived>
struct BFloat16Impl {
 protected:
  /// Converts a float to bfloat16 bits. NaN maps to kPositiveQNaNBits; other
  /// values are rounded to nearest, ties to even, on the upper 16 bits.
  static uint16_t ToUint16Impl(float v) noexcept;

  /// Widens the stored bits to float by zero-filling the low 16 bits.
  /// Any NaN is returned as std::numeric_limits<float>::quiet_NaN().
  float ToFloatImpl() const noexcept;

  /// `val` with the sign bit removed.
  uint16_t AbsImpl() const noexcept {
    return static_cast<uint16_t>(val & 0x7FFFU);
  }

  /// `val` with the sign bit toggled; NaN patterns pass through untouched.
  uint16_t NegateImpl() const noexcept {
    if (IsNaN()) {
      return val;
    }
    return static_cast<uint16_t>(val ^ kSignMask);
  }

 public:
  static constexpr uint16_t kSignMask = 0x8000U;
  static constexpr uint16_t kBiasedExponentMask = 0x7F80U;
  static constexpr uint16_t kPositiveInfinityBits = 0x7F80U;
  static constexpr uint16_t kNegativeInfinityBits = 0xFF80U;
  static constexpr uint16_t kPositiveQNaNBits = 0x7FC1U;
  static constexpr uint16_t kNegativeQNaNBits = 0xFFC1U;
  // NOTE(review): 0x7F80 is the +infinity pattern (identical to
  // kPositiveInfinityBits), not a signaling NaN. A signaling NaN needs a
  // nonzero mantissa with the quiet bit 0x0040 clear (e.g. 0x7F81). The value
  // is left unchanged for compatibility; confirm before use.
  static constexpr uint16_t kSignaling_NaNBits = 0x7F80U;
  static constexpr uint16_t kEpsilonBits = 0x0080U;  // 2^-126, smallest positive normal
  static constexpr uint16_t kMinValueBits = 0xFF7FU;  // lowest finite value
  static constexpr uint16_t kMaxValueBits = 0x7F7FU;  // largest finite value
  static constexpr uint16_t kRoundToNearest = 0x7FFFU;  // rounding bias for the low 16 bits
  static constexpr uint16_t kOneBits = 0x3F80U;
  static constexpr uint16_t kMinusOneBits = 0xBF80U;

  /// Raw bfloat16 bit pattern.
  uint16_t val{0};

  BFloat16Impl() = default;

  /// Sign bit set? (true for -0 and sign-bit-set NaNs as well.)
  bool IsNegative() const noexcept { return (val & kSignMask) != 0; }

  /// Exponent all ones with a nonzero mantissa.
  bool IsNaN() const noexcept { return kPositiveInfinityBits < AbsImpl(); }

  /// Neither infinite nor NaN.
  bool IsFinite() const noexcept { return kPositiveInfinityBits > AbsImpl(); }

  /// Exactly +infinity.
  bool IsPositiveInfinity() const noexcept { return kPositiveInfinityBits == val; }

  /// Exactly -infinity.
  bool IsNegativeInfinity() const noexcept { return kNegativeInfinityBits == val; }

  /// Infinity of either sign.
  bool IsInfinity() const noexcept { return kPositiveInfinityBits == AbsImpl(); }

  /// Zero of either sign, or any NaN.
  bool IsNaNOrZero() const noexcept { return AbsImpl() == 0 || IsNaN(); }

  /// Exponent field neither all zeros (zero/subnormal) nor all ones (inf/NaN).
  bool IsNormal() const noexcept {
    const uint16_t exponent = static_cast<uint16_t>(val & kBiasedExponentMask);
    return exponent != 0 && exponent != kBiasedExponentMask;
  }

  /// Exponent field all zeros but the value is not a zero.
  bool IsSubnormal() const noexcept {
    const uint16_t exponent = static_cast<uint16_t>(val & kBiasedExponentMask);
    return exponent == 0 && AbsImpl() != 0;
  }

  /// Magnitude (sign cleared), constructed through Derived::FromBits.
  Derived Abs() const noexcept { return Derived::FromBits(AbsImpl()); }

  /// Sign flipped (NaN unchanged), constructed through Derived::FromBits.
  Derived Negate() const noexcept { return Derived::FromBits(NegateImpl()); }

  /// Both operands are zeros, ignoring the sign bit (so +0 and -0 both count).
  static bool AreZero(const BFloat16Impl& lhs, const BFloat16Impl& rhs) noexcept {
    const unsigned combined = static_cast<unsigned>(lhs.val | rhs.val);
    return (combined & 0x7FFFU) == 0U;
  }
};
|
|
|
template <class Derived> |
|
inline uint16_t BFloat16Impl<Derived>::ToUint16Impl(float v) noexcept { |
|
uint16_t result; |
|
if (std::isnan(v)) { |
|
result = kPositiveQNaNBits; |
|
} else { |
|
auto get_msb_half = [](float fl) { |
|
uint16_t result; |
|
#ifdef __cpp_if_constexpr |
|
if constexpr (detail::endian::native == detail::endian::little) { |
|
#else |
|
if (detail::endian::native == detail::endian::little) { |
|
#endif |
|
std::memcpy(&result, reinterpret_cast<char*>(&fl) + sizeof(uint16_t), sizeof(uint16_t)); |
|
} else { |
|
std::memcpy(&result, &fl, sizeof(uint16_t)); |
|
} |
|
return result; |
|
}; |
|
|
|
uint16_t upper_bits = get_msb_half(v); |
|
union { |
|
uint32_t U32; |
|
float F32; |
|
}; |
|
F32 = v; |
|
U32 += (upper_bits & 1) + kRoundToNearest; |
|
result = get_msb_half(F32); |
|
} |
|
return result; |
|
} |
|
|
|
template <class Derived> |
|
inline float BFloat16Impl<Derived>::ToFloatImpl() const noexcept { |
|
if (IsNaN()) { |
|
return std::numeric_limits<float>::quiet_NaN(); |
|
} |
|
float result; |
|
char* const first = reinterpret_cast<char*>(&result); |
|
char* const second = first + sizeof(uint16_t); |
|
#ifdef __cpp_if_constexpr |
|
if constexpr (detail::endian::native == detail::endian::little) { |
|
#else |
|
if (detail::endian::native == detail::endian::little) { |
|
#endif |
|
std::memset(first, 0, sizeof(uint16_t)); |
|
std::memcpy(second, &val, sizeof(uint16_t)); |
|
} else { |
|
std::memcpy(first, &val, sizeof(uint16_t)); |
|
std::memset(second, 0, sizeof(uint16_t)); |
|
} |
|
return result; |
|
} |
|
|
|
} |
|
|