|
#pragma once |
|
|
|
|
|
|
|
|
|
#include <c10/macros/Macros.h> |
|
#include <cmath> |
|
#include <cstring> |
|
|
|
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000 |
|
#include <cuda_bf16.h> |
|
#endif |
|
|
|
namespace c10 { |
|
|
|
namespace detail { |
|
inline C10_HOST_DEVICE float f32_from_bits(uint16_t src) { |
|
float res = 0; |
|
uint32_t tmp = src; |
|
tmp <<= 16; |
|
|
|
#if defined(USE_ROCM) |
|
float* tempRes; |
|
|
|
|
|
|
|
tempRes = reinterpret_cast<float*>(&tmp); |
|
res = *tempRes; |
|
#else |
|
std::memcpy(&res, &tmp, sizeof(tmp)); |
|
#endif |
|
|
|
return res; |
|
} |
|
|
|
inline C10_HOST_DEVICE uint16_t bits_from_f32(float src) { |
|
uint32_t res = 0; |
|
|
|
#if defined(USE_ROCM) |
|
|
|
|
|
uint32_t* tempRes = reinterpret_cast<uint32_t*>(&src); |
|
res = *tempRes; |
|
#else |
|
std::memcpy(&res, &src, sizeof(res)); |
|
#endif |
|
|
|
return res >> 16; |
|
} |
|
|
|
inline C10_HOST_DEVICE uint16_t round_to_nearest_even(float src) { |
|
#if defined(USE_ROCM) |
|
if (src != src) { |
|
#elif defined(_MSC_VER) |
|
if (isnan(src)) { |
|
#else |
|
if (std::isnan(src)) { |
|
#endif |
|
return UINT16_C(0x7FC0); |
|
} else { |
|
union { |
|
uint32_t U32; |
|
float F32; |
|
}; |
|
|
|
F32 = src; |
|
uint32_t rounding_bias = ((U32 >> 16) & 1) + UINT32_C(0x7FFF); |
|
return static_cast<uint16_t>((U32 + rounding_bias) >> 16); |
|
} |
|
} |
|
} |
|
|
|
struct alignas(2) BFloat16 { |
|
uint16_t x; |
|
|
|
|
|
#if defined(USE_ROCM) |
|
C10_HOST_DEVICE BFloat16() = default; |
|
#else |
|
BFloat16() = default; |
|
#endif |
|
|
|
struct from_bits_t {}; |
|
static constexpr C10_HOST_DEVICE from_bits_t from_bits() { |
|
return from_bits_t(); |
|
} |
|
|
|
constexpr C10_HOST_DEVICE BFloat16(unsigned short bits, from_bits_t) |
|
: x(bits){}; |
|
inline C10_HOST_DEVICE BFloat16(float value); |
|
inline C10_HOST_DEVICE operator float() const; |
|
|
|
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000 |
|
inline C10_HOST_DEVICE BFloat16(const __nv_bfloat16& value); |
|
explicit inline C10_HOST_DEVICE operator __nv_bfloat16() const; |
|
#endif |
|
}; |
|
|
|
} |
|
|
|
#include <c10/util/BFloat16-inl.h> |
|
|