/*
 * This implementation is extracted from PyTorch:
 * Repo: github.com/pytorch/pytorch
 * File: torch/lib/TH/THHalf.c
 * Commit ID: 92481b59d31199df57420d4b14912348cc780d1d
 * Functions are made "static inline" for performance
 */

/* Copyright 1993-2014 NVIDIA Corporation.  All rights reserved. */

/*
 * Host functions for converting between FP32 and FP16 formats.
 *
 * Both functions reinterpret a float's bits as an unsigned integer. This is
 * done through a union rather than a pointer cast such as *(unsigned*)res:
 * the cast form violates C's strict-aliasing rule and an optimizing compiler
 * may reorder or drop the access, whereas reading a different union member
 * is defined to reinterpret the stored bytes.
 *
 * NOTE(review): like the original code, this assumes `unsigned` and `float`
 * are both 32 bits wide -- confirm for every target platform.
 */

/*
 * Convert an IEEE 754 binary16 bit pattern to a binary32 float.
 *
 * src: pointer to the 16-bit half-precision bit pattern (read only).
 * res: receives the equivalent single-precision value. Every finite half is
 *      exactly representable as a float, so no rounding occurs.
 *
 * Infinities keep their sign. Any NaN input is canonicalized to the positive
 * float NaN with all mantissa bits set (bit pattern 0x7fffffff); the sign and
 * payload are discarded.
 */
static inline void TH_halfbits2float(const unsigned short* src, float* res)
{
  union { unsigned u; float f; } out;
  unsigned h = *src;
  unsigned sign = ((h >> 15) & 1);
  unsigned exponent = ((h >> 10) & 0x1f);
  unsigned mantissa = ((h & 0x3ff) << 13); /* align 10-bit field to float's 23 */

  if (exponent == 0x1f) { /* NaN or Inf */
    if (mantissa) {
      /* NaN: emit a canonical positive NaN. */
      sign = 0;
      mantissa = 0x7fffff;
    }
    exponent = 0xff;
  } else if (!exponent) { /* Denorm or Zero */
    if (mantissa) {
      unsigned int msb;
      /* Half subnormals are normal floats: shift the leading 1 up to bit 23,
       * lowering the exponent once per shift. */
      exponent = 0x71;
      do {
        msb = (mantissa & 0x400000);
        mantissa <<= 1; /* normalize */
        --exponent;
      } while (!msb);
      mantissa &= 0x7fffff; /* 1.mantissa is implicit */
    }
  } else {
    exponent += 0x70; /* rebias: float bias 127 - half bias 15 = 112 = 0x70 */
  }

  out.u = ((sign << 31) | (exponent << 23) | mantissa);
  *res = out.f;
}

/*
 * Convert a binary32 float to an IEEE 754 binary16 bit pattern using
 * round-to-nearest, ties-to-even.
 *
 * src:  pointer to the float to convert (read only).
 * dest: receives the 16-bit half-precision bit pattern.
 *
 * Special cases:
 *  - any NaN          -> 0x7fff (sign and payload are discarded)
 *  - |x| >= 65520.0f  -> signed infinity (65504 is the largest finite half;
 *                        65520 is the tie that rounds up to 2^16)
 *  - |x| <= 2^-25     -> signed zero (2^-25 is exactly half the smallest
 *                        subnormal half, and the tie rounds to even, i.e. 0)
 */
static inline void TH_float2halfbits(const float* src, unsigned short* dest)
{
  union { float f; unsigned u; } in;
  unsigned x, u, remainder, shift, lsb, lsb_s1, lsb_m1;
  unsigned sign, exponent, mantissa;

  in.f = *src;
  x = in.u;
  u = (x & 0x7fffffff); /* magnitude bits */

  // Get rid of +NaN/-NaN case first.
  if (u > 0x7f800000) {
    *dest = (unsigned short)0x7fffU;
    return;
  }

  sign = ((x >> 16) & 0x8000);

  // Get rid of +Inf/-Inf, +0/-0.
  if (u > 0x477fefff) { /* >= 65520.0f rounds to infinity */
    *dest = (unsigned short)(sign | 0x7c00U);
    return;
  }
  if (u < 0x33000001) { /* <= 2^-25 rounds to zero */
    *dest = (unsigned short)(sign | 0x0000);
    return;
  }

  exponent = ((u >> 23) & 0xff);
  mantissa = (u & 0x7fffff);

  if (exponent > 0x70) {
    /* Result is a normal half: drop the 13 extra mantissa bits, rebias. */
    shift = 13;
    exponent -= 0x70;
  } else {
    /* Result is a half subnormal: restore the implicit leading 1 and shift
     * further right; given the early returns above, shift is in [14, 24]. */
    shift = 0x7e - exponent;
    exponent = 0;
    mantissa |= 0x800000;
  }

  lsb = (1u << shift);
  lsb_s1 = (lsb >> 1); /* halfway point of the discarded bits */
  lsb_m1 = (lsb - 1);  /* mask of the discarded bits */

  // Round to nearest even.
  remainder = (mantissa & lsb_m1);
  mantissa >>= shift;
  if (remainder > lsb_s1 || (remainder == lsb_s1 && (mantissa & 0x1))) {
    ++mantissa;
    /* Mantissa overflowed 10 bits: carry into the exponent. This also
     * promotes the largest subnormal to the smallest normal half. */
    if (!(mantissa & 0x3ff)) {
      ++exponent;
      mantissa = 0;
    }
  }

  *dest = (unsigned short)(sign | (exponent << 10) | mantissa);
}