Spaces:
Sleeping
Sleeping
| // Copyright 2021 Google LLC | |
| // | |
| // Licensed under the Apache License, Version 2.0 (the "License"); | |
| // you may not use this file except in compliance with the License. | |
| // You may obtain a copy of the License at | |
| // | |
| // http://www.apache.org/licenses/LICENSE-2.0 | |
| // | |
| // Unless required by applicable law or agreed to in writing, software | |
| // distributed under the License is distributed on an "AS IS" BASIS, | |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| // See the License for the specific language governing permissions and | |
| // limitations under the License. | |
| namespace csrblocksparse { | |
| // Maximum desired precision of the output, in mantissa bits. | |
| // TanhTable() and SigmoidTable() return nullptr when asked for more bits than | |
| // this, and each keeps a cache with this many slots (indexed by bits - 1). | |
| static constexpr int kMaxMantissaBits = 14; | |
| // Returns (and builds if not done yet) a static data table that implements | |
| // tanh on fixed32 input, returning another fixed32 with the given number of | |
| // mantissa bits (which is assumed to be less than the input mantissa bits). | |
| // NOTE that this function is intended to be used only with fixed16 outputs that | |
| // are sign-extended to 32 bits for convenience, and will return a nullptr | |
| // if asked for more than |kMaxMantissaBits| of precision in the output table. | |
| const int32_t* TanhTable(int num_mantissa_bits_out) { | |
| if (num_mantissa_bits_out > kMaxMantissaBits) return nullptr; | |
| // Static data dynamically created and never destructed. | |
| static const int32_t* tanh_luts[kMaxMantissaBits]; | |
| if (tanh_luts[num_mantissa_bits_out - 1] == nullptr) { | |
| // Total bits is number each side of the binary point. | |
| int tanh_lut_bits = num_mantissa_bits_out + kNumTanhExpBits; | |
| // Offset is the number of negative numbers represented. | |
| int tanh_offset = 1 << tanh_lut_bits; | |
| // Size is double the offset plus one more for zero. | |
| int tanh_size = tanh_offset * 2 + 1; | |
| // Conversion between int and float. | |
| float float_factor = static_cast<float>(1 << num_mantissa_bits_out); | |
| int* tanh_lut = new int[tanh_size]; | |
| // Initialize the table. | |
| for (int i = 0; i < tanh_size; ++i) { | |
| float x = (i - tanh_offset) / float_factor; | |
| tanh_lut[i] = static_cast<int>(std::round(tanhf(x) * float_factor)); | |
| } | |
| tanh_luts[num_mantissa_bits_out - 1] = tanh_lut; | |
| } | |
| return tanh_luts[num_mantissa_bits_out - 1]; | |
| } | |
| // As TanhTable, but for Sigmoid. | |
| const int32_t* SigmoidTable(int num_mantissa_bits_out) { | |
| if (num_mantissa_bits_out > kMaxMantissaBits) return nullptr; | |
| // Static data dynamically created and never destructed. | |
| static const int32_t* sigmoid_luts[kMaxMantissaBits]; | |
| if (sigmoid_luts[num_mantissa_bits_out - 1] == nullptr) { | |
| // Total bits is number each side of the binary point minus one for the fact | |
| // that the gradient never exceeds 1/4. (Could probably use -2.) | |
| int sigmoid_lut_bits = | |
| num_mantissa_bits_out + kNumSigmoidExpBits - kNumExtraSigmoidShiftBits; | |
| // Offset is the number of negative numbers represented. | |
| int sigmoid_offset = 1 << sigmoid_lut_bits; | |
| // Size is double the offset plus one more for zero. | |
| int sigmoid_size = sigmoid_offset * 2 + 1; | |
| // Conversion between int and float. | |
| float float_factor = static_cast<float>(1 << num_mantissa_bits_out); | |
| int* sigmoid_lut = new int[sigmoid_size]; | |
| // Initialize the table. | |
| for (int i = 0; i < sigmoid_size; ++i) { | |
| constexpr int kSigmoidFactor = 1 << kNumExtraSigmoidShiftBits; | |
| float x = ((i - sigmoid_offset) * kSigmoidFactor) / float_factor; | |
| float sigmoid = 1.0f / (1.0f + expf(-x)); | |
| sigmoid_lut[i] = static_cast<int>(std::round(sigmoid * float_factor)); | |
| } | |
| sigmoid_luts[num_mantissa_bits_out - 1] = sigmoid_lut; | |
| } | |
| return sigmoid_luts[num_mantissa_bits_out - 1]; | |
| } | |
| } // namespace csrblocksparse | |