test / src /f16-vtanh /avx-expm1minus.c.in
Androidonnxfork's picture
Upload folder using huggingface_hub
8b7c501
raw
history blame
No virus
12.7 kB
// Copyright 2023 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
// Kernel variants are selected by template parameters, validated below:
//   P, H       - polynomial variant; the kernel name encodes them as p<P>h<H>.
//   PS         - selects the "ps" (instead of "ts") name suffix; only allowed
//                for (P, H) == (4, 3).
//   DIV        - "DIV" divides exactly, "RCP" multiplies by _mm256_rcp_ps.
//   SAT        - "MINMAX" clamps the input from below, "SELECT" blends in -1
//                after the division.
//   AVX, FMA   - instruction set; AVX == 1 splits 256-bit integer shifts into
//                two 128-bit halves, FMA == 3 uses fused multiply-adds.
//   BATCH_TILE - elements per main-loop iteration, a positive multiple of 8.
$assert P == H + 1 or P == H + 2
$assert not PS or (P, H) == (4, 3)
$assert DIV in ["DIV", "RCP"]
$assert SAT in ["MINMAX", "SELECT"]
$assert AVX != 2 or FMA == 3
$assert BATCH_TILE % 8 == 0
$assert BATCH_TILE >= 8
$SIMD_TILE = BATCH_TILE // 8
#include <assert.h>
#include <stddef.h>
#include <stdint.h>
#include <immintrin.h>
#include <xnnpack/common.h>
#include <xnnpack/intrinsics-polyfill.h>
#include <xnnpack/microparams.h>
#include <xnnpack/vunary.h>
$POLY_SUFFIX = "p%dh%d%s" % (P, H, "ps" if PS else "ts")
$DIV_SUFFIX = DIV.lower()
$PARAMS_STRUCT = "avx_expm1minus_rr1_" + ("p%dh%d" % (P, H))
$ISA = "avx2" if AVX == 2 else "fma3" if FMA == 3 else "f16c"
// Elementwise tanh on half-precision values. Each group of 8 inputs is widened
// to single precision (_mm256_cvtph_ps), evaluated, and narrowed back
// (_mm256_cvtps_ph, round-to-nearest). `batch` is a size in bytes and must be
// a non-zero multiple of sizeof(uint16_t).
//
// Per-lane outline (every constant is read from `params`):
//   1. The raw input bits are OR-ed with sign_mask so the math runs on a
//      non-positive z (presumably sign_mask is 0x8000, giving z = -|x|). The
//      bits that changed (vinvsignx) are XOR-ed into the half-precision result
//      at the end, restoring the sign of the input.
//   2. Saturation: MINMAX clamps z from below at sat_cutoff; SELECT instead
//      overwrites the result with -1 in lanes where z <= sat_cutoff.
//   3. Range reduction: vn = z * log2e + magic_bias. Shifting the bits of vn
//      left by 23 moves its low bits into the float exponent field to form the
//      scale s. After removing the bias, t = z + n * minus_ln2.
//   4. A polynomial in t (coefficients c2..cP together with `two`) is combined
//      with ts = t * s and smo = s - 1 into emo; then epo = emo + 2.
//   5. y = emo / epo, computed with a true division (DIV) or with the
//      approximate _mm256_rcp_ps reciprocal followed by a multiply (RCP).
// NOTE(review): the identity tanh(z) = expm1(2z) / (expm1(2z) + 2) suggests
// log2e and minus_ln2 are pre-scaled so that emo approximates exp(2z) - 1;
// confirm against the params initializer.
void xnn_f16_vtanh_ukernel__${ISA}_expm1minus_rr1_${POLY_SUFFIX}_${DIV_SUFFIX}_x${BATCH_TILE}(
    size_t batch,
    const void* input,
    void* output,
    const union xnn_f16_tanh_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
{
  assert(batch != 0);
  assert(batch % sizeof(uint16_t) == 0);
  assert(input != NULL);
  assert(output != NULL);

  const __m128i vsign_mask = _mm_load_si128((const __m128i*) params->${PARAMS_STRUCT}.sign_mask);
  const __m256 vsat_cutoff = _mm256_load_ps(params->${PARAMS_STRUCT}.sat_cutoff);
  const __m256 vlog2e = _mm256_load_ps(params->${PARAMS_STRUCT}.log2e);
  const __m256 vmagic_bias = _mm256_load_ps(params->${PARAMS_STRUCT}.magic_bias);
  const __m256 vminus_ln2 = _mm256_load_ps(params->${PARAMS_STRUCT}.minus_ln2);
  $for i in reversed(range(2, P+1)):
    const __m256 vc${i} = _mm256_load_ps(params->${PARAMS_STRUCT}.c${i});
  // Both constants are needed by every variant: minus_one for s - 1 (and the
  // SELECT saturation value), two for emo + 2.
  const __m256 vminus_one = _mm256_load_ps(params->${PARAMS_STRUCT}.minus_one);
  const __m256 vtwo = _mm256_load_ps(params->${PARAMS_STRUCT}.two);

  const uint16_t* i = (const uint16_t*) input;
  uint16_t* o = (uint16_t*) output;
  $if BATCH_TILE > 8:
    // Main loop: several independent 8-element vectors per iteration.
    for (; batch >= ${BATCH_TILE} * sizeof(uint16_t); batch -= ${BATCH_TILE} * sizeof(uint16_t)) {
      const __m128i vx0 = _mm_loadu_si128((const __m128i*) i);
      $for N in range(1, SIMD_TILE):
        const __m128i vx${N} = _mm_loadu_si128((const __m128i*) (i + ${N * 8}));
      i += ${BATCH_TILE};

      $for N in range(SIMD_TILE):
        const __m128i vabsx${N} = _mm_or_si128(vx${N}, vsign_mask);

      $for N in range(SIMD_TILE):
        __m256 vz${N} = _mm256_cvtph_ps(vabsx${N});
        const __m128i vinvsignx${N} = _mm_xor_si128(vx${N}, vabsx${N});

      $for N in range(SIMD_TILE):
        $if SAT == "MINMAX":
          vz${N} = _mm256_max_ps(vsat_cutoff, vz${N});
        $elif SAT == "SELECT":
          const __m256 vm${N} = _mm256_cmp_ps(vz${N}, vsat_cutoff, _CMP_LE_OS);
        $if FMA == 3:
          __m256 vn${N} = _mm256_fmadd_ps(vz${N}, vlog2e, vmagic_bias);
        $else:
          __m256 vn${N} = _mm256_add_ps(_mm256_mul_ps(vz${N}, vlog2e), vmagic_bias);

      $if AVX == 1:
        $for N in range(SIMD_TILE):
          const __m128 vn${N}_hi = _mm256_extractf128_ps(vn${N}, 1);
          __m256 vs${N} = _mm256_castps128_ps256(_mm_castsi128_ps(_mm_slli_epi32(_mm_castps_si128(_mm256_castps256_ps128(vn${N})), 23)));
          vn${N} = _mm256_sub_ps(vn${N}, vmagic_bias);
        $for N in range(SIMD_TILE):
          const __m128 vs${N}_hi = _mm_castsi128_ps(_mm_slli_epi32(_mm_castps_si128(vn${N}_hi), 23));
        $for N in range(SIMD_TILE):
          vs${N} = _mm256_insertf128_ps(vs${N}, vs${N}_hi, 1);
      $else:
        $for N in range(SIMD_TILE):
          const __m256 vs${N} = _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_castps_si256(vn${N}), 23));
          vn${N} = _mm256_sub_ps(vn${N}, vmagic_bias);

      $for N in range(SIMD_TILE):
        $if FMA == 3:
          const __m256 vt${N} = _mm256_fmadd_ps(vn${N}, vminus_ln2, vz${N});
        $else:
          const __m256 vt${N} = _mm256_add_ps(_mm256_mul_ps(vn${N}, vminus_ln2), vz${N});

      $if FMA == 3:
        $for N in range(SIMD_TILE):
          __m256 vp${N} = vc${P};
        $for i in reversed(range(2, P)):
          $for N in range(SIMD_TILE):
            vp${N} = _mm256_fmadd_ps(vp${N}, vt${N}, vc${i});
      $else:
        $for N in range(SIMD_TILE):
          __m256 vp${N} = _mm256_add_ps(_mm256_mul_ps(vc${P}, vt${N}), vc${P-1});
        $for i in reversed(range(2, P-1)):
          $for N in range(SIMD_TILE):
            vp${N} = _mm256_add_ps(_mm256_mul_ps(vp${N}, vt${N}), vc${i});
      $if P == H + 1:
        $for N in range(SIMD_TILE):
          $if FMA == 3:
            vp${N} = _mm256_fmadd_ps(vp${N}, vt${N}, vtwo);
          $else:
            vp${N} = _mm256_add_ps(_mm256_mul_ps(vp${N}, vt${N}), vtwo);
      $else:
        $for N in range(SIMD_TILE):
          vp${N} = _mm256_mul_ps(vp${N}, vt${N});

      $for N in range(SIMD_TILE):
        const __m256 vts${N} = _mm256_mul_ps(vt${N}, vs${N});
        const __m256 vsmo${N} = _mm256_add_ps(vs${N}, vminus_one);

      $if P == H + 1:
        $for N in range(SIMD_TILE):
          $if FMA == 3:
            const __m256 vemo${N} = _mm256_fmadd_ps(vp${N}, vts${N}, vsmo${N});
          $else:
            const __m256 vemo${N} = _mm256_add_ps(_mm256_mul_ps(vp${N}, vts${N}), vsmo${N});
      $else:
        $if FMA == 3:
          $for N in range(SIMD_TILE):
            vp${N} = _mm256_fmadd_ps(vp${N}, vts${N}, vts${N});
          $for N in range(SIMD_TILE):
            const __m256 vemo${N} = _mm256_fmadd_ps(vp${N}, vtwo, vsmo${N});
        $else:
          $for N in range(SIMD_TILE):
            vp${N} = _mm256_add_ps(_mm256_mul_ps(vp${N}, vts${N}), vts${N});
          $for N in range(SIMD_TILE):
            const __m256 vemo${N} = _mm256_add_ps(_mm256_mul_ps(vp${N}, vtwo), vsmo${N});

      $for N in range(SIMD_TILE):
        const __m256 vepo${N} = _mm256_add_ps(vemo${N}, vtwo);

      $if DIV == "DIV":
        $for N in range(SIMD_TILE):
          __m256 vy${N} = _mm256_div_ps(vemo${N}, vepo${N});
      $else:
        $for N in range(SIMD_TILE):
          __m256 vrepo${N} = _mm256_rcp_ps(vepo${N});
        $for N in range(SIMD_TILE):
          __m256 vy${N} = _mm256_mul_ps(vemo${N}, vrepo${N});

      $if SAT == "SELECT":
        $for N in range(SIMD_TILE):
          vy${N} = _mm256_blendv_ps(vy${N}, vminus_one, vm${N});

      $for N in range(SIMD_TILE):
        __m128i vh${N} = _mm256_cvtps_ph(vy${N}, _MM_FROUND_TO_NEAREST_INT);

      $for N in range(SIMD_TILE):
        vh${N} = _mm_xor_si128(vh${N}, vinvsignx${N});

      _mm_storeu_si128((__m128i*) o, vh0);
      $for N in range(1, SIMD_TILE):
        _mm_storeu_si128((__m128i*) (o + ${N * 8}), vh${N});
      o += ${BATCH_TILE};
    }
  // One 8-element vector per iteration.
  for (; batch >= 8 * sizeof(uint16_t); batch -= 8 * sizeof(uint16_t)) {
    const __m128i vx = _mm_loadu_si128((const __m128i*) i);
    i += 8;

    const __m128i vabsx = _mm_or_si128(vx, vsign_mask);
    __m256 vz = _mm256_cvtph_ps(vabsx);
    const __m128i vinvsignx = _mm_xor_si128(vx, vabsx);
    $if SAT == "MINMAX":
      vz = _mm256_max_ps(vsat_cutoff, vz);
    $elif SAT == "SELECT":
      const __m256 vm = _mm256_cmp_ps(vz, vsat_cutoff, _CMP_LE_OS);
    $if FMA == 3:
      __m256 vn = _mm256_fmadd_ps(vz, vlog2e, vmagic_bias);
    $else:
      __m256 vn = _mm256_add_ps(_mm256_mul_ps(vz, vlog2e), vmagic_bias);

    $if AVX == 1:
      const __m128 vn_hi = _mm256_extractf128_ps(vn, 1);
      __m256 vs = _mm256_castps128_ps256(_mm_castsi128_ps(_mm_slli_epi32(_mm_castps_si128(_mm256_castps256_ps128(vn)), 23)));
      const __m128 vs_hi = _mm_castsi128_ps(_mm_slli_epi32(_mm_castps_si128(vn_hi), 23));
      vs = _mm256_insertf128_ps(vs, vs_hi, 1);
    $else:
      const __m256 vs = _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_castps_si256(vn), 23));
    vn = _mm256_sub_ps(vn, vmagic_bias);

    $if FMA == 3:
      const __m256 vt = _mm256_fmadd_ps(vn, vminus_ln2, vz);
    $else:
      const __m256 vt = _mm256_add_ps(_mm256_mul_ps(vn, vminus_ln2), vz);

    $if FMA == 3:
      __m256 vp = vc${P};
      $for i in reversed(range(2, P)):
        vp = _mm256_fmadd_ps(vp, vt, vc${i});
    $else:
      __m256 vp = _mm256_add_ps(_mm256_mul_ps(vc${P}, vt), vc${P-1});
      $for i in reversed(range(2, P-1)):
        vp = _mm256_add_ps(_mm256_mul_ps(vp, vt), vc${i});
    $if P == H + 1:
      $if FMA == 3:
        vp = _mm256_fmadd_ps(vp, vt, vtwo);
      $else:
        vp = _mm256_add_ps(_mm256_mul_ps(vp, vt), vtwo);
    $else:
      vp = _mm256_mul_ps(vp, vt);

    const __m256 vts = _mm256_mul_ps(vt, vs);
    const __m256 vsmo = _mm256_add_ps(vs, vminus_one);
    $if P == H + 1:
      $if FMA == 3:
        const __m256 vemo = _mm256_fmadd_ps(vp, vts, vsmo);
      $else:
        const __m256 vemo = _mm256_add_ps(_mm256_mul_ps(vp, vts), vsmo);
    $else:
      $if FMA == 3:
        vp = _mm256_fmadd_ps(vp, vts, vts);
        const __m256 vemo = _mm256_fmadd_ps(vp, vtwo, vsmo);
      $else:
        vp = _mm256_add_ps(_mm256_mul_ps(vp, vts), vts);
        const __m256 vemo = _mm256_add_ps(_mm256_mul_ps(vp, vtwo), vsmo);
    const __m256 vepo = _mm256_add_ps(vemo, vtwo);

    $if DIV == "DIV":
      __m256 vy = _mm256_div_ps(vemo, vepo);
    $else:
      __m256 vrepo = _mm256_rcp_ps(vepo);
      __m256 vy = _mm256_mul_ps(vemo, vrepo);
    $if SAT == "SELECT":
      vy = _mm256_blendv_ps(vy, vminus_one, vm);

    __m128i vh = _mm256_cvtps_ph(vy, _MM_FROUND_TO_NEAREST_INT);
    vh = _mm_xor_si128(vh, vinvsignx);
    _mm_storeu_si128((__m128i*) o, vh);
    o += 8;
  }
  // Remaining 1-7 elements. A full 16-byte vector is loaded (the XNN_OOB_READS
  // annotation presumably marks this over-read of the input as permitted), but
  // only the remaining elements are stored, in chunks of 4, 2 and 1.
  if (batch != 0) {
    const __m128i vx = _mm_loadu_si128((const __m128i*) i);

    const __m128i vabsx = _mm_or_si128(vx, vsign_mask);
    __m256 vz = _mm256_cvtph_ps(vabsx);
    const __m128i vinvsignx = _mm_xor_si128(vx, vabsx);
    $if SAT == "MINMAX":
      vz = _mm256_max_ps(vsat_cutoff, vz);
    $elif SAT == "SELECT":
      const __m256 vm = _mm256_cmp_ps(vz, vsat_cutoff, _CMP_LE_OS);
    $if FMA == 3:
      __m256 vn = _mm256_fmadd_ps(vz, vlog2e, vmagic_bias);
    $else:
      __m256 vn = _mm256_add_ps(_mm256_mul_ps(vz, vlog2e), vmagic_bias);

    $if AVX == 1:
      const __m128 vn_hi = _mm256_extractf128_ps(vn, 1);
      __m256 vs = _mm256_castps128_ps256(_mm_castsi128_ps(_mm_slli_epi32(_mm_castps_si128(_mm256_castps256_ps128(vn)), 23)));
      const __m128 vs_hi = _mm_castsi128_ps(_mm_slli_epi32(_mm_castps_si128(vn_hi), 23));
      vs = _mm256_insertf128_ps(vs, vs_hi, 1);
    $else:
      const __m256 vs = _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_castps_si256(vn), 23));
    vn = _mm256_sub_ps(vn, vmagic_bias);

    $if FMA == 3:
      const __m256 vt = _mm256_fmadd_ps(vn, vminus_ln2, vz);
    $else:
      const __m256 vt = _mm256_add_ps(_mm256_mul_ps(vn, vminus_ln2), vz);

    $if FMA == 3:
      __m256 vp = vc${P};
      $for i in reversed(range(2, P)):
        vp = _mm256_fmadd_ps(vp, vt, vc${i});
    $else:
      __m256 vp = _mm256_add_ps(_mm256_mul_ps(vc${P}, vt), vc${P-1});
      $for i in reversed(range(2, P-1)):
        vp = _mm256_add_ps(_mm256_mul_ps(vp, vt), vc${i});
    $if P == H + 1:
      $if FMA == 3:
        vp = _mm256_fmadd_ps(vp, vt, vtwo);
      $else:
        vp = _mm256_add_ps(_mm256_mul_ps(vp, vt), vtwo);
    $else:
      vp = _mm256_mul_ps(vp, vt);

    const __m256 vts = _mm256_mul_ps(vt, vs);
    const __m256 vsmo = _mm256_add_ps(vs, vminus_one);
    $if P == H + 1:
      $if FMA == 3:
        const __m256 vemo = _mm256_fmadd_ps(vp, vts, vsmo);
      $else:
        const __m256 vemo = _mm256_add_ps(_mm256_mul_ps(vp, vts), vsmo);
    $else:
      $if FMA == 3:
        vp = _mm256_fmadd_ps(vp, vts, vts);
        const __m256 vemo = _mm256_fmadd_ps(vp, vtwo, vsmo);
      $else:
        vp = _mm256_add_ps(_mm256_mul_ps(vp, vts), vts);
        const __m256 vemo = _mm256_add_ps(_mm256_mul_ps(vp, vtwo), vsmo);
    const __m256 vepo = _mm256_add_ps(vemo, vtwo);

    $if DIV == "DIV":
      __m256 vy = _mm256_div_ps(vemo, vepo);
    $else:
      __m256 vrepo = _mm256_rcp_ps(vepo);
      __m256 vy = _mm256_mul_ps(vemo, vrepo);
    $if SAT == "SELECT":
      vy = _mm256_blendv_ps(vy, vminus_one, vm);

    __m128i vh = _mm256_cvtps_ph(vy, _MM_FROUND_TO_NEAREST_INT);
    vh = _mm_xor_si128(vh, vinvsignx);
    if (batch & (4 * sizeof(uint16_t))) {
      _mm_storel_epi64((__m128i*) o, vh);
      vh = _mm_unpackhi_epi64(vh, vh);
      o += 4;
    }
    if (batch & (2 * sizeof(uint16_t))) {
      _mm_storeu_si32(o, vh);
      vh = _mm_srli_epi64(vh, 32);
      o += 2;
    }
    if (batch & (1 * sizeof(uint16_t))) {
      *o = (uint16_t) _mm_extract_epi16(vh, 0);
    }
  }
}